用PyTorch实现线性模型
四个步骤:
- Prepare dataset
- Design model using Class
- Inherit from nn.Module
- Construct loss and optimizer
- Training cycle
- forward, backward, update
Design model using Class
演示所用的线性模型:
$\hat{y}=\omega*x+b$线性模型类的定义:
1 | class LinearModel(torch.nn.Module): |
补充Python的callable知识点:
1 | class Foobar: |
当一个类定义了__call__(self, *args, **kwargs)
函数,这个类的实例(上例中为foobar)就可以直接调用了
nn.Linear
类也执行了__call__()
函数,并在这个函数调用了forward()
函数
所以说用model=LinearModel()
声明了model
之后,就可以用model(x)
来对输入数据进行变换了,LinearModel
的__call__()
函数会调用对应的forward()
函数
Construct loss and optimizer
1 | criterion = torch.nn.MSELoss(size_average=False) |
Training Cycle
前馈,反馈,更新
1 | for epoch in range(100): |
补充:PyTorch中常用的优化器:
1 | torch.optim.Adagrad |
课程来源:《PyTorch深度学习实践》完结合集