用Python和PyTorch编写一个LSTM模型，基于MNIST数据集完成手写数字识别
时间: 2024-05-11 15:16:54 浏览: 145
好的，我可以帮你回答这个问题。以下是用Python编写LSTM模型、并使用PyTorch在MNIST数据集上完成手写数字识别的代码示例：
```python
import torch
from torch import nn
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Data preprocessing: ToTensor turns each PIL image into a float tensor in
# [0, 1]; Normalize then standardises it with the commonly used MNIST
# mean/std pair (0.1307, 0.3081).
_MNIST_MEAN = (0.1307,)
_MNIST_STD = (0.3081,)
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(_MNIST_MEAN, _MNIST_STD),
    ]
)

# Load the MNIST train/test splits, downloading into ./data if missing.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
# Model definition
class LSTMModel(nn.Module):
    """LSTM classifier that reads an image one row per time step.

    Each input is reshaped into a sequence of rows ``input_size`` wide
    (for 28x28 MNIST: 28 time steps of 28 features). The top LSTM layer's
    output at the final time step is mapped to class logits.

    Args:
        input_size: features per time step (image width). Default 28.
        hidden_size: number of LSTM hidden units. Default 64.
        num_layers: number of stacked LSTM layers. Default 2.
        num_classes: number of output logits. Default 10.
    """

    def __init__(self, input_size=28, hidden_size=64, num_layers=2, num_classes=10):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return raw logits of shape ``(batch, num_classes)``.

        ``x`` may be an image batch such as ``(batch, 1, H, W)`` or an
        already-sequenced ``(batch, seq_len, input_size)`` tensor. In both
        cases the trailing elements must be divisible into rows of
        ``input_size``.
        """
        # nn.LSTM(batch_first=True) requires a 3-D (batch, seq, feature)
        # tensor, but DataLoader batches of ToTensor images are 4-D
        # (batch, 1, 28, 28). Fold everything after the batch dim into rows.
        x = x.reshape(x.size(0), -1, self.lstm.input_size)
        out, _ = self.lstm(x)
        # Classify from the last time step's output of the top layer.
        return self.fc(out[:, -1, :])


model = LSTMModel()
# Loss function and optimizer.
# CrossEntropyLoss consumes raw logits (it applies log-softmax itself), which
# is why the model ends in a plain Linear layer; Adam updates all parameters.
criterion = nn.CrossEntropyLoss()
learning_rate = 0.01
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Training function
def train(model, train_dataloader, optimizer, criterion):
    """Run one training epoch over *train_dataloader*, updating *model* in place.

    Args:
        model: network to train; it is switched to training mode.
        train_dataloader: iterable yielding ``(inputs, targets)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss callable invoked as ``criterion(output, targets)``.

    Returns:
        The sample-weighted mean training loss for the epoch, or
        ``float('nan')`` if the loader yields no samples. This assumes
        *criterion* averages over the batch, as ``nn.CrossEntropyLoss`` does
        by default. Callers that ignore the return value are unaffected.
    """
    model.train()
    total_loss = 0.0
    n_samples = 0
    for x, y in train_dataloader:
        # Plain tensors track gradients directly; torch.autograd.Variable has
        # been a deprecated no-op wrapper since PyTorch 0.4.
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        # Weight by batch size so a short final batch is not over-counted.
        batch_size = y.size(0)
        total_loss += loss.item() * batch_size
        n_samples += batch_size
    return total_loss / n_samples if n_samples else float('nan')
# 定义测试函数
def test(model, test_dataloader, criterion):
model.eval()
total_loss = 0.0
correct = 0
for i, data in enumerate(test_dataloader):
x, y = data
x, y = Variable(x), Variable(y)
output = model(x)
loss = criterion(output, y)
total_loss += loss.data[0]
_, predicted = torch.max(output.data, 1)
correct += (predicted == y.data).sum()
avg_loss = total_loss / len(test_dataset)
accuracy = float(correct) / len(test_dataset)
return avg_loss, accuracy
# Train and evaluate the model.
# Single source of truth for the epoch count, used both by the loop and by
# the progress message, so the two cannot drift apart.
num_epochs = 10
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)
for epoch in range(num_epochs):
    train(model, train_dataloader, optimizer, criterion)
    # Report test-set loss and accuracy after each epoch.
    avg_loss, accuracy = test(model, test_dataloader, criterion)
    print('Epoch [{}/{}], Loss: {:.4f}, Accuracy: {:.2f}%'.format(
        epoch + 1, num_epochs, avg_loss, accuracy * 100))
```
这段代码会在每个epoch结束后输出测试集上的平均损失和准确率。模型把每张28×28的图像按行看作28个时间步、每步28维特征输入LSTM，并用最后一个时间步的输出进行分类。你可以根据需要修改代码中的一些参数，如学习率、训练轮数等。
阅读全文