torch.nn.linear用法
时间: 2023-11-26 12:04:46 浏览: 92
浅析PyTorch中nn.Linear的使用
`torch.nn.Linear` 是 PyTorch 中的一个函数,用于定义一个全连接层,它的用法如下:
```python
torch.nn.Linear(in_features: int, out_features: int, bias: bool = True)
```
其中:
- `in_features`:输入特征的数量。
- `out_features`:输出特征的数量。
- `bias`:是否使用偏置项,默认为 True。
`torch.nn.Linear` 的作用是将输入特征线性变换为输出特征,即 $y = xA^T + b$,其中 $x$ 是输入特征,$y$ 是输出特征,$A$ 是权重矩阵,$b$ 是偏置向量。
使用示例:
```python
import torch.nn as nn
# 定义一个全连接层,输入特征为10,输出特征为5
linear = nn.Linear(10, 5)
# 输入特征是一个大小为[batch_size, 10]的张量
input_tensor = torch.randn(3, 10)
# 调用全连接层进行计算
output_tensor = linear(input_tensor)
# 输出特征是一个大小为[batch_size, 5]的张量
print(output_tensor.size())
```
阅读全文