torch.load示例代码
时间: 2023-11-06 19:18:47 浏览: 48
当使用PyTorch中的torch.load函数加载模型或者其他保存的对象时,通常需要提供保存路径作为输入参数。下面是一个示例代码:
```python
import torch
# 定义模型
class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.linear = torch.nn.Linear(10, 5)
def forward(self, x):
return self.linear(x)
model = MyModel()
# 保存模型
save_path = 'model.pth'
torch.save(model.state_dict(), save_path)
# 加载模型
loaded_model = MyModel()
loaded_model.load_state_dict(torch.load(save_path))
```
在上面的示例中,首先我们定义了一个简单的模型`MyModel`,然后使用`torch.save`函数将模型的参数保存到名为`model.pth`的文件中。接下来,我们创建了另一个`MyModel`的实例`loaded_model`,并使用`torch.load`函数加载之前保存的模型参数,并使用`load_state_dict`方法将参数加载到新创建的模型实例中。
请注意,`state_dict`是一个字典对象,其中包含模型的参数和缓冲区。这种方式可以方便地保存和加载模型的状态,同时也适用于其他PyTorch对象的保存和加载。