pytorch之中parameters()是什么,举例子用代码详细说明
时间: 2023-04-06 17:03:28 浏览: 62
parameters()是PyTorch中的一个函数,它可以返回一个包含模型所有可训练参数的迭代器。这些可训练参数包括权重和偏置等。
以下是一个简单的例子,演示如何使用parameters()函数来获取模型的所有可训练参数:
```
import torch.nn as nn
# 定义一个简单的神经网络模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(10, 20)
self.fc2 = nn.Linear(20, 1)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
return x
# 创建一个模型实例
net = Net()
# 使用parameters()函数获取模型的所有可训练参数
params = list(net.parameters())
# 打印模型的所有可训练参数
for i, param in enumerate(params):
print("参数{}:{}".format(i, param))
```
输出结果如下:
```
参数0:Parameter containing:
tensor([[ 0.0335, -0.1465, -0.1315, 0.2567, -0.2832, -0.1862, -0.0729, -0.2417,
-0.2210, -0.1848, -0.0745, -0.0974, -0.0806, -0.1004, -0.1296, -0.0627,
-0.1077, 0.1072, -0.1092, -0.1576]], requires_grad=True)
参数1:Parameter containing:
tensor([ 0.2319, -0.0396, -0.1127, -0.0413, -0.2663, -0.0224, -0.1917, -0.2032,
-0.1276, -0.2461, -0.2049, -0.2017, -0.0805, -0.2238, -0.2785, -0.2395,
-0.2463, -0.0193, -0.1848, -0.2627], requires_grad=True)
参数2:Parameter containing:
tensor([[ 0.1605, -0.1845, -0.0421, -0.1018, 0.1149, -0.1180, -0.1807, -0.1747,
-0.0638, -0.1076, -0.0404, 0.0487, -0.0632, -0.1442, -0.1662, -0.1377,
-0.1936, -0.1033, -0.1296, -0.1637]], requires_grad=True)
参数3:Parameter containing:
tensor([-0.0086], requires_grad=True)
```
可以看到,该模型的所有可训练参数都被成功地获取到了。