Python中nn.Linear
时间: 2023-10-26 10:26:16 浏览: 107
浅析PyTorch中nn.Linear的使用
5星 · 资源好评率100%
nn.Linear是PyTorch(一个流行的深度学习框架)中的一个类,用于定义一个线性变换层。它在神经网络中常用于实现全连接层。
nn.Linear接受两个参数:输入特征的数量和输出特征的数量。例如,如果你想将一个具有10个输入特征和5个输出特征的层添加到你的神经网络中,你可以使用以下代码:
```
import torch
import torch.nn as nn
# 输入特征数量为10,输出特征数量为5
linear_layer = nn.Linear(10, 5)
# 使用线性层进行输入数据的变换
input_data = torch.randn(100, 10) # 生成一个形状为(100, 10)的张量作为输入数据
output = linear_layer(input_data)
```
在上面的示例中,`input_data`是一个形状为(100, 10)的张量,表示100个样本,每个样本具有10个特征。`linear_layer(input_data)`将对输入数据进行线性变换,并返回一个形状为(100, 5)的张量,表示100个样本,每个样本具有5个输出特征。
使用nn.Linear可以方便地定义和使用线性变换层,它会自动管理权重和偏置项,并且可以与其他PyTorch中的层一起构建神经网络模型。
阅读全文