nn.Sequential是干什么的
时间: 2024-04-28 11:16:31 浏览: 86
nn.Sequential是一个PyTorch中的容器,用于按照顺序添加和组合多个神经网络模块。它可以自动完成前向传播,并且非常方便地构建深度神经网络。
具体来说,nn.Sequential可以通过传入一系列的神经网络模块来构建一个神经网络。这些模块将按照传入的顺序依次添加到Sequential容器中。当进行前向传播时,输入数据会依次经过每个模块,并且每个模块的输出将作为下一个模块的输入,最终得到最终的输出结果。
下面是一个示例代码,展示了如何使用nn.Sequential构建一个简单的神经网络:
```python
import torch
import torch.nn as nn
# 定义一个简单的神经网络模块
class BasicNet(nn.Module):
def __init__(self):
super(BasicNet, self).__init__()
self.fc1 = nn.Linear(3, 5)
self.fc2 = nn.Linear(5, 2)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
# 使用nn.Sequential构建神经网络
net = nn.Sequential(
BasicNet(),
nn.ReLU(),
nn.Linear(2, 1)
)
# 进行前向传播
input = torch.randn(10, 3)
output = net(input)
print(output)
```
在上面的代码中,我们首先定义了一个简单的神经网络模块BasicNet,然后使用nn.Sequential将它和其他模块按照顺序添加到Sequential容器中。最后,我们使用输入数据进行前向传播,并输出最终的结果。
阅读全文