请给我一个代码例子,是关于pytorch中怎么保存深度学习训练后的模型
时间: 2024-03-27 15:41:32 浏览: 63
PyTorch Lightning:简化深度学习训练的革命
当我们训练好一个深度学习模型后,我们需要将其保存到磁盘上,以便在以后的时间内使用它。在 PyTorch 中,我们可以使用 `torch.save` 函数来保存训练后的模型。下面是一个保存模型的例子:
```python
import torch
# 定义模型
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.linear = torch.nn.Linear(10, 1)
def forward(self, x):
return self.linear(x)
# 创建模型和优化器
model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 训练模型
for epoch in range(10):
# 假设我们有一些数据 x 和目标 y
x = torch.randn(16, 10)
y = torch.randn(16, 1)
# 前向传播和计算损失
y_pred = model(x)
loss = torch.nn.functional.mse_loss(y_pred, y)
# 反向传播和更新梯度
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 保存模型
torch.save(model.state_dict(), 'model.pt')
```
在上述代码中,我们首先定义了一个模型 `Model`,并创建了一个优化器 `optimizer`。然后,我们使用一些数据来训练模型,其中我们使用了均方误差损失函数。最后,我们使用 `torch.save` 函数来保存模型的状态字典(state_dict)到名为 `model.pt` 的文件中。这个状态字典包含了模型的所有参数和缓存项。
阅读全文