用python写CNN模型输入是28x28x1图片,输出15类
时间: 2024-03-26 10:40:51 浏览: 56
以下是一个简单的CNN模型,输入是28x28x1的灰度图像,输出15类。需要注意的是,这只是一个简单的模型,如果需要更高的准确率,需要对模型进行更多的调整和优化。
```python
import torch.nn as nn
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1, 1)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(64 * 7 * 7, 128)
self.fc2 = nn.Linear(128, 15)
def forward(self, x):
x = self.pool(nn.functional.relu(self.conv1(x)))
x = self.pool(nn.functional.relu(self.conv2(x)))
x = x.view(-1, 64 * 7 * 7)
x = nn.functional.relu(self.fc1(x))
x = self.fc2(x)
return x
```
模型中包含两个卷积层和两个全连接层,其中第一个卷积层的输入通道数为1,输出通道数为32,卷积核大小为3x3,步长为1,填充为1,第二个卷积层的输入通道数为32,输出通道数为64,卷积核大小为3x3,步长为1,填充为1。两个卷积层后接最大池化层,池化核大小为2x2,步长为2。两个全连接层的神经元数分别为128和15。最后的输出层不需要激活函数,因为在损失函数中使用的交叉熵损失已经包含了softmax操作。
阅读全文