CNN pytorch 代码
时间: 2024-09-25 21:05:01 浏览: 26
CNN (卷积神经网络) 在 PyTorch 中的实现通常包含以下几个关键步骤:
```python
import torch
import torch.nn as nn
# 定义一个简单的卷积神经网络模型
class ConvNet(nn.Module):
def __init__(self):
super(ConvNet, self).__init__()
# 卷积层
self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(64) # 归一化层
self.relu = nn.ReLU() # 激活函数
# 池化层
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
# 全连接层
self.fc1 = nn.Linear(64 * 8 * 8, 128) # 假设输入是8x8的特征图
self.drop = nn.Dropout(p=0.5) # 防止过拟合
self.fc2 = nn.Linear(128, num_classes) # 根据任务调整num_classes
def forward(self, x):
x = self.pool(self.relu(self.bn1(self.conv1(x))))
x = x.view(-1, 64 * 8 * 8) # 展平
x = self.relu(self.fc1(x))
x = self.drop(x)
x = self.fc2(x)
return x
# 创建模型实例并设置随机种子
model = ConvNet()
torch.manual_seed(0)
```
在这个例子中,我们首先导入必要的库,然后创建一个`ConvNet`类,它有卷积层、池化层、全连接层等组件。`forward`方法描述了数据流过网络的方式。在实际应用中,还需要对模型进行训练和评估。