nn.linear
时间: 2023-07-13 11:15:51 浏览: 133
Pytorch|神经网络工具箱nn
`nn.Linear` 是 PyTorch 中的一个类,用于定义一个线性转换(即仿射变换),将输入张量的每个元素乘以权重矩阵,加上偏置向量,得到输出张量。其定义如下:
```python
class torch.nn.Linear(in_features: int, out_features: int, bias: bool = True)
```
其中,`in_features` 是输入张量的特征数,即输入张量的最后一个维度的大小;`out_features` 是输出张量的特征数,即输出张量的最后一个维度的大小;`bias` 表示是否使用偏置向量,默认为 `True`。
例如,定义一个输入大小为 10,输出大小为 5 的线性层:
```python
import torch.nn as nn
linear = nn.Linear(10, 5)
```
此时,`linear` 就是一个 `nn.Linear` 的实例,可以使用它来进行线性变换。
阅读全文