pytorch 保存模型参数
时间: 2023-09-27 08:02:44 浏览: 121
在PyTorch中,我们可以使用torch.save()
函数来保存模型的参数。该函数的第一个参数是要保存的模型参数,可以是模型的state_dict或完整的模型对象。第二个参数是保存的文件路径。
下面是一个保存模型参数的简单例子:
import torch
import torch.nn as nn
# 创建一个简单的模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc = nn.Linear(10, 2)
def forward(self, x):
return self.fc(x)
model = Net()
# 保存模型参数
torch.save(model.state_dict(), 'model.pth')
在上述代码中,我们首先定义了一个简单的神经网络模型Net
,包含一个全连接层。然后,我们创建了一个实例化的模型对象model
。
最后,我们使用torch.save()
函数将模型的state_dict保存到文件model.pth
中。
要加载保存的模型参数,可以使用torch.load()
函数。下面是加载模型参数的例子:
# 加载模型参数
model = Net()
model.load_state_dict(torch.load('model.pth'))
在上述代码中,我们首先创建了一个新的模型对象model
,然后使用torch.load()
函数加载之前保存的模型参数。通过这种方式,我们可以恢复训练过的模型或在其他任务中使用保存的模型参数。
相关推荐


















