介绍pytorch中的nn.Module类
时间: 2024-02-22 09:28:37 浏览: 85
在PyTorch中,`nn.Module`是所有神经网络模块的基类。它是一个封装了参数、计算方法以及其他网络组件的类,可以用来构建自己的神经网络模型。
每个`nn.Module`子类的构造函数中都应该调用基类的构造函数。在`__init__`方法中,我们可以定义网络中的各个层、参数和其他组件。我们也可以在`forward`方法中定义网络的前向传播过程,即输入数据经过一系列计算后得到输出结果。
`nn.Module`提供了很多实用的方法,例如`parameters`方法可以返回模型中所有可训练的参数,`to`方法可以将模型转移到指定的设备上等。
示例代码:
```python
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU(inplace=True)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc = nn.Linear(16 * 14 * 14, 10)
def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.pool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
model = MyModel()
input = torch.randn(1, 3, 28, 28)
output = model(input)
```
这里我们定义了一个简单的卷积神经网络模型,包括了一个卷积层、一个ReLU激活函数、一个最大池化层和一个全连接层。在`forward`方法中,我们定义了输入数据的前向传播过程。我们可以通过调用`parameters`方法打印出模型中的所有参数:`print(list(model.parameters()))`。我们还可以使用`to`方法将模型转移到GPU上:`model.to(device)`。
阅读全文