nn.Linear定义
时间: 2024-01-21 17:13:57 浏览: 108
nn.Linear是PyTorch中的一个类,用于定义线性变换层。它将输入的特征进行线性变换,并输出一个新的特征表示。nn.Linear的定义如下:
```python
class nn.Linear(in_features, out_features, bias=True)
```
其中,参数说明如下:
- in_features:输入特征的大小,即输入的维度。
- out_features:输出特征的大小,即输出的维度。
- bias:是否使用偏置项,默认为True。
nn.Linear的作用是将输入的特征进行线性变换,即计算输入特征与权重矩阵的乘积,并加上偏置项。它可以用于构建神经网络的全连接层。
下面是一个使用nn.Linear的例子:
```python
import torch
import torch.nn as nn
# 定义输入特征的大小和输出特征的大小
in_features = 10
out_features = 5
# 创建一个nn.Linear对象
linear = nn.Linear(in_features, out_features)
# 定义输入数据
input_data = torch.randn(2, in_features)
# 进行线性变换
output = linear(input_data)
# 输出结果
print(output)
```
在这个例子中,我们首先创建了一个nn.Linear对象,然后定义了输入特征的大小和输出特征的大小。接着,我们创建了一个输入数据的张量,并将其传递给nn.Linear进行线性变换。最后,我们打印输出结果。
阅读全文