nn.Linear怎么使用
时间: 2023-10-14 12:21:44 浏览: 207
nn.Linear是PyTorch中的一个类,用于实现线性变换(也称为全连接层)操作。可以使用如下步骤进行使用:
1. 导入PyTorch和nn模块:
```python
import torch
import torch.nn as nn
```
2. 创建一个nn.Linear对象,指定输入特征和输出特征的数量:
```python
linear_layer = nn.Linear(in_features, out_features)
```
其中,in_features是输入特征的数量,out_features是输出特征的数量。
3. 将数据传递给线性层进行计算:
```python
output = linear_layer(inputs)
```
其中,inputs是一个张量,其形状应该是(batch_size,in_features)。
完整的代码示例:
```python
import torch
import torch.nn as nn
# 定义输入特征和输出特征的数量
in_features = 10
out_features = 5
# 创建线性层
linear_layer = nn.Linear(in_features, out_features)
# 创建输入数据
inputs = torch.randn(2, in_features)
# 进行计算
output = linear_layer(inputs)
# 输出结果
print(output)
```
输出结果:
```
tensor([[-0.3069, 0.3474, 0.0154, 0.3210, 0.0042],
[-1.2913, -0.1960, -0.2553, -0.1433, -0.5273]], grad_fn=<AddmmBackward>)
```
这里创建了一个输入特征数量为10,输出特征数量为5的线性层,输入了一个形状为(2,10)的张量,输出了一个形状为(2,5)的张量。
阅读全文