nn.parameter
时间: 2023-03-24 13:00:56 浏览: 131
`nn.Parameter`是PyTorch中的一个类,用于定义可学习参数(learnable parameters)。在神经网络中,我们通常需要优化一些参数,例如权重和偏置(weights and biases)。`nn.Parameter`允许我们将这些参数包装成可训练的对象,以便能够自动地在反向传播(backpropagation)期间更新它们的值。
下面是一个简单的示例,说明如何使用`nn.Parameter`创建可学习的参数:
```
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.weight = nn.Parameter(torch.randn(3, 4))
self.bias = nn.Parameter(torch.randn(3))
def forward(self, x):
out = torch.matmul(x, self.weight.t()) + self.bias
return out
```
在这个示例中,我们创建了一个名为`MyModel`的简单神经网络模型。该模型包含一个权重矩阵和一个偏置向量,它们都被声明为`nn.Parameter`类型。在`__init__`方法中,我们使用`torch.randn`函数初始化这些参数。
在`forward`方法中,我们使用权重矩阵和偏置向量来计算输入张量`x`的输出。由于我们使用了`nn.Parameter`来声明这些参数,PyTorch将自动追踪这些参数的计算图,并在反向传播期间更新它们的值。
值得注意的是,`nn.Parameter`实际上只是`torch.Tensor`的一个子类,因此我们可以像使用普通张量一样使用它们。但是,当我们使用`nn.Parameter`时,PyTorch将自动将这些张量标记为需要进行梯度计算(gradient computation)的参数,以便能够自动地更新它们的值。
阅读全文