nn.linear函数
时间: 2023-06-14 09:03:47 浏览: 83
Pytorch|神经网络工具箱nn
`nn.Linear` 是 PyTorch 中的一个类,用于定义全连接层(fully connected layer)。它接收两个参数:输入的特征数和输出的特征数,自动创建权重和偏差参数,并在前向传递时应用它们。
具体来说,`nn.Linear` 的前向传递可以表示为以下公式:
$$y = xA^T + b$$
其中,$x$ 是输入的特征向量,$A$ 是权重矩阵,$b$ 是偏差向量,$y$ 是输出的特征向量。
在 PyTorch 中使用 `nn.Linear` 的示例代码如下:
```python
import torch.nn as nn
# 定义一个全连接层,输入特征数为10,输出特征数为5
fc = nn.Linear(10, 5)
# 假设输入特征向量 x 的形状为 (batch_size, 10)
x = torch.randn(64, 10)
# 前向传递
y = fc(x)
# 输出特征向量 y 的形状为 (batch_size, 5)
```
阅读全文