最简单的cnn神经网络Python代码
时间: 2024-10-14 08:13:43 浏览: 21
在PyTorch中,最简单的卷积神经网络(CNN)结构通常包括几个基本组件:输入层、卷积层、池化层和全连接层。以下是一个非常基础的CNN模型构建示例:
```python
import torch
import torch.nn as nn
class SimpleCNN(nn.Module):
def __init__(self, input_channels=3, num_classes=10):
super(SimpleCNN, self).__init__()
# 卷积层1
self.conv1 = nn.Conv2d(input_channels, 16, kernel_size=3, stride=1, padding=1)
self.relu1 = nn.ReLU()
# 池化层1
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
# 卷积层2
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
self.relu2 = nn.ReLU()
# 池化层2
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
# 全连接层
self.fc = nn.Linear(32 * 8 * 8, num_classes) # 假设输入图像大小为32x32,计算出的维度
def forward(self, x):
x = self.conv1(x)
x = self.relu1(x)
x = self.pool1(x)
x = self.conv2(x)
x = self.relu2(x)
x = self.pool2(x)
# 将二维张量展平成一维向量以便于全连接层处理
x = x.view(-1, 32 * 8 * 8)
x = self.fc(x)
return x
# 使用模型
model = SimpleCNN()
# 假设x_data是一个Tensor,形状为(batch_size, input_channels, image_height, image_width)
output = model(x_data)
阅读全文