请写一段python代码,创建一个简单的CNN神经网络模型并用torchvision中的MNIST数据集对其进行训练
时间: 2023-05-12 17:04:08 浏览: 99
基于Python的卷积神经网络(CNN)识别MNIST数据集.zip
5星 · 资源好评率100%
当然,我可以为您编写一段简单的Python代码,用于创建一个CNN神经网络模型并使用torchvision中的MNIST数据集进行训练。以下是代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
# 定义CNN模型
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2)
self.relu1 = nn.ReLU()
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2)
self.relu2 = nn.ReLU()
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc1 = nn.Linear(7*7*32, 10)
def forward(self, x):
x = self.conv1(x)
x = self.relu1(x)
x = self.pool1(x)
x = self.conv2(x)
x = self.relu2(x)
x = self.pool2(x)
x = x.view(-1, 7*7*32)
x = self.fc1(x)
return x
# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)
# 定义训练参数
batch_size = 64
learning_rate = 0.01
num_epochs = 10
# 加载数据集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.Data
阅读全文