nn.moudle详解
时间: 2023-06-12 12:03:45 浏览: 193
`nn.Module` 是 PyTorch 中用于构建神经网络模型的基类。所有自定义的神经网络模型都应该继承自 `nn.Module` 类。`nn.Module` 类提供了很多有用的方法,使得神经网络模型的构建和训练变得更加简单。
`nn.Module` 类的常用方法:
- `__init__(self)`: 初始化函数,用于定义模型的结构和参数;
- `forward(self, input)`: 前向传播函数,用于定义模型的计算过程;
- `parameters(self)`: 返回模型中所有需要训练的参数(权重和偏置);
- `named_parameters(self)`: 返回模型中所有需要训练的参数及其名称;
- `zero_grad(self)`: 将模型中所有参数的梯度设置为0;
- `to(self, device)`: 将模型移动到指定的设备上(如:CPU或GPU);
- `train(self, mode=True)`: 设置模型为训练模式,当模型中包含一些会改变的操作时(如:Dropout),需要调用该方法;
- `eval(self)`: 设置模型为评估模式,当模型中包含一些不会改变的操作时(如:BatchNorm),需要调用该方法。
下面是一个简单的例子,展示了如何使用 `nn.Module` 类来构建一个神经网络模型:
```python
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
x = nn.functional.relu(nn.functional.max_pool2d(self.conv1(x), 2))
x = nn.functional.relu(nn.functional.max_pool2d(self.conv2(x), 2))
x = x.view(-1, 320)
x = nn.functional.relu(self.fc1(x))
x = self.fc2(x)
return nn.functional.log_softmax(x, dim=1)
```
在这个例子中,我们定义了一个名为 `Net` 的模型,继承自 `nn.Module` 类。在 `__init__` 方法中,我们定义了模型的结构和参数,包括两个卷积层和两个全连接层。在 `forward` 方法中,我们定义了模型的计算过程,包括卷积、池化、全连接和激活操作。我们还可以通过调用其他 `nn.Module` 类中的方法,如 `nn.functional.relu` 和 `nn.functional.log_softmax` 等,来方便地构建模型。
阅读全文