PyTorch 识别手写数字代码
时间: 2023-02-28 07:22:24 浏览: 70
# Example code that can help you recognize handwritten digits (MNIST):
# a small fully connected network trained with PyTorch.
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Download and load the data.
train_dataset = datasets.MNIST(root='./dataset', train=True, download=True, transform=transforms.ToTensor())
test_dataset = datasets.MNIST(root='./dataset', train=False, download=True, transform=transforms.ToTensor())

# Iterating a Dataset directly yields a single (image, int label) pair, which
# CrossEntropyLoss and the accuracy bookkeeping below cannot handle (they need
# batched tensors). Wrap both datasets in DataLoaders to get batches.
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Create a model: an MLP mapping a flattened 784-feature image to 10 class logits.
model = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)

# Define the loss and the optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Train the model.
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    for images, labels in train_loader:
        optimizer.zero_grad()
        # Flatten every image in the batch into one 784-element row.
        outputs = model(images.view(images.size(0), -1))
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # `loss` is the mean over the batch; weight it by the batch size so the
        # epoch figure below is a true per-sample average even when the last
        # batch is smaller.
        train_loss += loss.item() * images.size(0)
    print('Epoch [{}/{}], Train Loss: {:.4f}'.format(epoch + 1, num_epochs, train_loss / len(train_dataset)))

# Test the model.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images.view(images.size(0), -1))
        # Index of the largest logit is the predicted digit.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        # .item() keeps `correct` a plain int so the accuracy prints as a number
        # rather than as a tensor repr.
        correct += (predicted == labels).sum().item()
print('Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))
相关推荐













