保存模型的函数应该写在哪儿?给出具体的示例
时间: 2024-02-29 07:52:32 浏览: 22
保存模型的函数可以写在训练脚本的末尾或者单独写成一个保存模型的脚本。以下是一个保存模型的示例代码:
```python
import torch
# 定义模型
class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.linear = torch.nn.Linear(10, 1)
def forward(self, x):
x = self.linear(x)
return x
# 定义数据和优化器
data = torch.randn(100, 10)
target = torch.randn(100, 1)
optimizer = torch.optim.SGD(MyModel().parameters(), lr=0.1)
# 定义损失函数
criterion = torch.nn.MSELoss()
# 训练模型
model = MyModel()
for epoch in range(10):
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
# 保存模型
torch.save(model.state_dict(), 'model.pth')
```
上述代码中,我们定义了一个简单的神经网络模型MyModel,并使用随机数据进行训练。训练完成后,我们使用`torch.save()`函数将模型的参数保存到文件`model.pth`中。这个文件包含了模型的所有参数,可以在之后重新加载模型并进行预测。