nn.Parameter()函数的用法
时间: 2023-10-04 08:08:41 浏览: 91
`nn.Parameter()`函数是PyTorch中用于创建可训练参数的类。它是`torch.Tensor`的子类,但被包装在`nn.Parameter`类中可以自动注册为模型的可训练参数。这意味着在模型的训练过程中,这些参数的值可以被优化器自动更新。
以下是`nn.Parameter()`函数的用法示例:
```python
import torch
import torch.nn as nn
# 创建一个可训练参数
param = nn.Parameter(torch.randn(3, 4))
print(param)
# 输出: Parameter containing:
# tensor([[-0.2193, 0.2330, -1.0280, 1.5679],
# [ 0.8522, -0.3909, -1.0786, -0.1507],
# [ 0.2862, 1.5818, 0.7069, -0.6836]], requires_grad=True)
# 将可训练参数用于模型
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.weight = nn.Parameter(torch.randn(5, 5))
def forward(self, x):
return torch.matmul(x, self.weight)
model = MyModel()
print(model)
# 输出: MyModel(
# (weight): Parameter containing: [torch.FloatTensor of size 5x5]
# )
# 访问可训练参数的值
print(model.weight)
# 输出: Parameter containing: [torch.FloatTensor of size 5x5]
# 使用可训练参数进行前向传播
input = torch.randn(2, 5)
output = model(input)
print(output)
```
在上述示例中,我们首先使用`nn.Parameter()`函数创建了一个形状为`(3, 4)`的可训练参数`param`。然后,
阅读全文