写一个pytorch代码构建CNN
时间: 2024-10-04 16:04:30 浏览: 13
当然可以。在PyTorch中,我们可以创建一个简单的卷积神经网络(Convolutional Neural Network,简称CNN)用于图像分类任务。这里是一个基础版本的示例,它包括一个基本的卷积层、池化层以及全连接层:
```python
import torch
import torch.nn as nn
import torch.optim as optim
class SimpleCNN(nn.Module):
def __init__(self, input_channels, num_classes):
super(SimpleCNN, self).__init__()
# 卷积层
self.conv1 = nn.Conv2d(input_channels, 16, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
# 全连接层
self.fc1 = nn.Linear(16 * 8 * 8, 64) # 假设输入图片大小为8x8经过前面的处理后变成16x8x8
self.fc2 = nn.Linear(64, num_classes)
def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.pool(x)
x = x.view(-1, 16 * 8 * 8) # 扁平化以便于全连接层处理
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
# 示例用法
input_channels = 3 # RGB 图像通常有3通道
num_classes = 10 # 类别数假设为10 (例如MNIST数据集)
model = SimpleCNN(input_channels, num_classes)
# 初始化优化器和损失函数
optimizer = optim.SGD(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
```