nn.ParameterList
时间: 2023-10-05 13:07:34 浏览: 204
`nn.ParameterList`是PyTorch中的一个类,用于管理模型参数的列表。它是`nn.Module`的子类,可以作为`nn.Module`的属性使用来管理一组模型参数。
下面是一个使用`nn.ParameterList`的示例:
```python
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.params = nn.ParameterList([nn.Parameter(torch.randn(3, 4)), nn.Parameter(torch.randn(5, 2))])
def forward(self, x):
# 使用参数
return x
model = MyModel()
print(model.params)
```
在这个示例中,我们创建了一个自定义的模型`MyModel`,其中包含了一个`nn.ParameterList`属性`params`,该属性初始化为包含两个参数的列表。这里每个参数都是使用`nn.Parameter`函数创建的随机张量。在模型的前向传播方法中,我们可以使用这些参数来进行计算。
当我们创建模型实例`model`并打印`model.params`时,将会输出模型的参数列表:
```
ParameterList(
(0): Parameter containing: tensor([[ 0.1913, -0.2123, 0.3911, -0.2667],
[-1.0693, -0.1225, -0.4909, 0.1232], [-0.2306, 0.6144, -1.8813, 0.0137]], requires_grad=True)
(1): Parameter containing: tensor([[ 1.3175, -1.3870],
[ 0.0423, -0.3143],
[-1.6024, -0.9305],
[-1.4996, 0.3935],
[ 1.1657, 2.2699]], requires_grad=True)
)
```
可以看到,`model.params`包含了两个参数,每个参数都是一个`nn.Parameter`对象。
阅读全文