nn.Parameter
时间: 2023-10-04 16:12:42 浏览: 95
parameter.m
`nn.Parameter` 是 PyTorch 中的一个类,用于将一个 Tensor 转换成一个可训练的参数,以便在模型训练过程中进行优化。与普通的 Tensor 不同,`nn.Parameter` 对象会自动被注册到模型的参数列表中,从而使其可以被自动更新。同时,`nn.Parameter` 也会继承 Tensor 的所有方法和属性,可以像普通的 Tensor 一样进行计算和操作。例如:
```
import torch.nn as nn
# 创建一个 3x3 的参数矩阵
param = nn.Parameter(torch.randn(3, 3))
# 将参数矩阵用于模型中的线性变换
linear = nn.Linear(3, 2)
linear.weight = param
# 在反向传播时,参数矩阵会自动进行梯度计算和更新
```
阅读全文