torch.nn.linear函数具体使用案例
时间: 2023-06-12 08:06:31 浏览: 128
`torch.nn.Linear` 是 PyTorch 中的一个线性变换模块,对输入数据进行线性变换,即 Y = XW^T + b,其中 X 为输入,W 为权重矩阵,b 为偏置向量,Y 为输出。
具体使用案例如下:
```python
import torch
import torch.nn as nn
# 定义输入数据维度为 2,输出数据维度为 3
input_dim = 2
output_dim = 3
# 定义一个 Linear 层
linear_layer = nn.Linear(input_dim, output_dim)
# 定义输入数据
input_data = torch.randn(4, input_dim)
# 进行线性变换
output_data = linear_layer(input_data)
print("输入数据:\n", input_data)
print("输出数据:\n", output_data)
```
输出结果如下:
```
输入数据:
tensor([[-1.5819, -0.1541],
[-1.1012, -1.1826],
[-0.2129, -0.3476],
[ 0.5284, -1.5889]])
输出数据:
tensor([[ 0.2243, -1.1177, -0.1403],
[ 0.5021, -0.6952, -0.8592],
[ 0.5135, -0.3315, -0.0834],
[ 0.5424, -0.3735, -0.0146]], grad_fn=<AddmmBackward>)
```
可以看到,输入数据的维度为 4x2,输出数据的维度为 4x3,这是因为输入数据经过线性变换后,每个样本的特征维度从 2 变成了 3。