nn.linear 输入特征大小怎么计算
时间: 2023-05-26 22:07:03 浏览: 49
对于 nn.Linear 模块,输入特征大小指的是每个样本的特征维度大小,通常也被称为输入层的大小。如果输入数据是一个 $(N, D_{in})$ 的矩阵,其中 $N$ 表示样本数,$D_{in}$ 表示每个样本的特征维度大小,则输入特征大小为 $D_{in}$。
在实际使用中,可以通过打印输入数据的 shape 属性获取输入特征大小。例如:
```
import torch
batch_size, input_size = 32, 784
x = torch.randn(batch_size, input_size)
print(x.shape) # (32, 784)
```
其中,batch_size 表示样本数量,input_size 表示每个样本的特征维度大小。在 nn.Linear 模块中定义时,可以使用 input_size 作为输入特征维度大小。例如:
```
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self, input_size, hidden_size):
super(MyModel, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, 1)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
return x
model = MyModel(input_size=784, hidden_size=256)
```
在上面的例子中,MyModel 的输入特征维度大小为 784。通过 nn.Linear 模块定义了两个全连接层,分别为输入层和隐藏层,其中输入层的输入特征维度大小为 784,隐藏层的输出特征维度大小为 256。