利用pytorch搭建CNN
时间: 2023-08-09 16:06:56 浏览: 60
当利用PyTorch搭建卷积神经网络(CNN)时,需要使用`torch.nn`模块中的类来定义网络的结构和操作。以下是一个简单的示例,展示如何使用PyTorch搭建一个包含卷积层、池化层和全连接层的CNN。
```python
import torch
import torch.nn as nn
# 定义CNN模型
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
# 卷积层
self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)
# 池化层
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
# 全连接层
self.fc1 = nn.Linear(16 * 14 * 14, 10) # 输入尺寸:16个通道,14x14的特征图
def forward(self, x):
# 前向传播
x = self.conv1(x)
x = nn.ReLU()(x)
x = self.pool(x)
x = x.view(x.size(0), -1) # 展平特征图
x = self.fc1(x)
return x
# 创建CNN模型实例
model = CNN()
# 打印模型结构
print(model)
```
在上述代码中,我们定义了一个名为`CNN`的子类,继承自`nn.Module`。在`__init__`方法中,我们定义了卷积层`self.conv1`、池化层`self.pool`和全连接层`self.fc1`。在`forward`方法中,我们定义了模型的前向传播过程,通过调用卷积层、激活函数、池化层和全连接层来构建整个网络结构。
通过创建`CNN`类的实例,我们可以得到一个具有卷积层、池化层和全连接层的CNN模型。最后,我们使用`print(model)`打印模型的结构。
这只是一个简单的CNN模型示例,你可以根据具体任务的需求,自定义和添加更多的卷积层、池化层和全连接层来构建更复杂的网络结构。