torch.fx.graph_module.GraphModule.load_state_dict()的用法示例
时间: 2024-05-04 09:20:12 浏览: 361
`torch.fx.graph_module.GraphModule.load_state_dict()` 方法可以用于加载模型参数。以下是一个示例:
```python
import torch
from torch.fx.graph_module import GraphModule
# 定义一个简单的模型
class SimpleModel(torch.nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc1 = torch.nn.Linear(10, 5)
self.relu1 = torch.nn.ReLU()
self.fc2 = torch.nn.Linear(5, 2)
self.softmax = torch.nn.Softmax(dim=1)
def forward(self, x):
x = self.fc1(x)
x = self.relu1(x)
x = self.fc2(x)
x = self.softmax(x)
return x
# 创建一个实例并保存模型参数
model = SimpleModel()
torch.save(model.state_dict(), 'model.pt')
# 创建一个 GraphModule 实例
gm = GraphModule(model, torch.randn(1, 10))
# 加载模型参数
gm.load_state_dict(torch.load('model.pt'))
# 打印模型参数
for name, param in gm.named_parameters():
print(name, param)
```
在上面的示例中,我们创建了一个简单的模型 `SimpleModel`,并将其保存到 `model.pt` 文件中。然后,我们创建了一个 `GraphModule` 实例 `gm`,并使用 `load_state_dict()` 方法加载了模型参数。最后,我们可以使用 `named_parameters()` 方法打印模型参数。
阅读全文