pytorch 保存模型
时间: 2023-09-01 12:11:04 浏览: 86
Pytorch 保存模型生成图片方式
5星 · 资源好评率100%
在 PyTorch 中,我们可以使用 `torch.save` 函数来保存模型。该函数的第一个参数是要保存的对象,通常是一个模型或优化器,第二个参数是文件名(包括路径和扩展名),文件名的扩展名通常是 `.pt` 或 `.pth`。
以下代码演示了如何保存一个简单的神经网络模型:
```python
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 2)
def forward(self, x):
x = self.fc1(x)
x = torch.relu(x)
x = self.fc2(x)
return x
model = Net()
# 保存模型
torch.save(model.state_dict(), 'model.pt')
```
在上面的示例中,我们创建了一个简单的神经网络模型,并使用 `torch.save` 函数将其保存到文件 `model.pt` 中。模型的状态字典可以通过 `model.state_dict()` 方法获取,它包含了所有模型的参数和缓存。
如果要加载已经保存的模型,可以使用 `torch.load` 函数,如下所示:
```python
# 加载模型
model.load_state_dict(torch.load('model.pt'))
```
以上代码将从文件 `model.pt` 中加载模型的状态字典,并将其设置为模型的参数。
阅读全文