prthon中nn.Linear
时间: 2023-10-14 20:28:45 浏览: 78
浅析PyTorch中nn.Module的使用
5星 · 资源好评率100%
nn.Linear是PyTorch中的一个模块,用于定义一个线性层(fully connected layer)。该模块可以接收输入张量,对其进行线性变换,并输出结果张量。nn.Linear通常用于构建神经网络的全连接层,可以实现输入特征的线性组合,并输出到下一层进行非线性变换。
具体来说,nn.Linear模块的构造函数如下:
```
class torch.nn.Linear(in_features: int, out_features: int, bias: bool = True)
```
其中,in_features表示输入张量的特征数,out_features表示输出张量的特征数,bias表示是否使用偏置项。构造一个nn.Linear对象后,可以通过调用该对象的forward方法来进行线性变换,如下所示:
```
import torch.nn as nn
linear_layer = nn.Linear(10, 20) # 构造一个输入特征数为10,输出特征数为20的线性层
input_tensor = torch.randn(32, 10) # 构造一个大小为32x10的随机输入张量
output_tensor = linear_layer(input_tensor) # 对输入张量进行线性变换
```
上述代码中,我们首先构造了一个输入特征数为10,输出特征数为20的线性层,然后构造了一个大小为32x10的随机输入张量,并通过调用linear_layer的forward方法对其进行线性变换,得到大小为32x20的输出张量。
阅读全文