for name, param in model.state_dict().items()例子
时间: 2024-03-05 13:51:43 浏览: 116
以下是一个简单的例子,展示如何使用`for name, param in model.state_dict().items()`来遍历模型的参数:
``` python
import torch
import torch.nn as nn
# 定义一个简单的神经网络模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(10, 20)
self.fc2 = nn.Linear(20, 1)
def forward(self, x):
x = self.fc1(x)
x = torch.relu(x)
x = self.fc2(x)
return x
# 创建一个Net类的实例
model = Net()
# 遍历模型的参数,并打印参数的名称和张量大小
for name, param in model.state_dict().items():
print(name, param.size())
```
输出结果为:
```
fc1.weight torch.Size([20, 10])
fc1.bias torch.Size([20])
fc2.weight torch.Size([1, 20])
fc2.bias torch.Size([1])
```
这个例子中,我们创建了一个名为`Net`的简单神经网络模型,并创建了一个`Net`类的实例`model`。使用`for name, param in model.state_dict().items()`遍历了模型的参数,并打印了每个参数的名称和张量大小。
阅读全文