什么是nn.Module
时间: 2023-10-23 09:10:14 浏览: 159
pytorch 中的重要模块化接口nn.Module的使用
`nn.Module` 是 PyTorch 中的一个基础类,用于构建神经网络。在 PyTorch 中,神经网络模型的构建和训练都是通过组合 `nn.Module` 的子类来完成的。
`nn.Module` 为神经网络提供了许多有用的功能,包括:
- 自动追踪模型中的参数;
- 可以使用 `.cuda()` 方法将模型移动到 GPU 上进行训练;
- 基于反向传播自动计算梯度。
当我们定义神经网络时,通常需要实现 `nn.Module` 的子类,并在其中定义网络的结构和前向传播方法。例如:
```python
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
self.fc = nn.Linear(32 * 8 * 8, 10)
def forward(self, x):
x = self.conv1(x)
x = nn.functional.relu(x)
x = nn.functional.max_pool2d(x, 2)
x = self.conv2(x)
x = nn.functional.relu(x)
x = nn.functional.max_pool2d(x, 2)
x = x.view(-1, 32 * 8 * 8)
x = self.fc(x)
return x
```
在这个例子中,我们定义了一个名为 `MyModel` 的子类,并在其构造函数中定义了三个层(两个卷积层和一个全连接层)。在 `forward` 方法中,我们定义了网络的前向传播逻辑。这个模型可以接受一个输入张量 `x`,并返回一个输出张量。整个模型的结构就是输入 x 经过一系列的卷积层、激活函数、池化层和全连接层后得到的输出。
在实际使用中,我们可以通过创建一个 `MyModel` 的实例来构建一个神经网络,并调用它的 `forward` 方法来进行前向传播计算。例如:
```python
model = MyModel()
input_tensor = torch.randn(1, 3, 32, 32)
output_tensor = model(input_tensor)
```
阅读全文