cnn神经网络模型代码
时间: 2023-07-30 18:06:47 浏览: 50
以下是一个简单的 CNN 神经网络模型代码,使用 PyTorch 框架实现:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5)
self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5)
self.fc1 = nn.Linear(in_features=32*4*4, out_features=128)
self.fc2 = nn.Linear(in_features=128, out_features=10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, kernel_size=2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, kernel_size=2)
x = x.view(-1, 32*4*4)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
```
这个模型包括两个卷积层和两个全连接层,用于对 MNIST 数据集进行图像分类。在 forward 方法中,输入的 x 是一个 28x28 的灰度图像,通过卷积层和池化层进行特征提取,最终输出一个 10 维的向量,表示输入图像属于 10 个数字中的哪一个。