RNN做cifar-10图像分类任务pytorch版本效果好且画出loss变化曲线以及准确率变化曲线的代码
时间: 2023-12-25 16:05:27 浏览: 83
RNN不太适合图像分类任务,因为图像数据的维度较高,需要大量的计算资源和时间。但是,可以通过将图像数据展平成一维向量,然后将其输入到RNN中进行分类。以下是使用PyTorch实现一个基于RNN的图像分类器来处理CIFAR-10数据集的代码示例:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# 图像预处理
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# 加载数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)
# 构建模型
class RNNModel(nn.Module):
def __init__(self):
super(RNNModel, self).__init__()
self.rnn = nn.RNN(input_size=3072, hidden_size=512, num_layers=2, batch_first=True)
self.fc = nn.Linear(512, 10)
def forward(self, x):
# 将输入数据reshape成(batch_size, seq_len, feature_dim)
x = x.view(-1, 3072, 1).transpose(1, 2)
x, _ = self.rnn(x)
x = x[:, -1, :]
x = self.fc(x)
return x
net = RNNModel()
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
# 训练模型
loss_list = []
acc_list = []
for epoch in range(10): # 多批次循环
running_loss = 0.0
correct = 0
total = 0
for i, data in enumerate(trainloader, 0):
# 获取输入
inputs, labels = data
# 梯度清零
optimizer.zero_grad()
# 前向传播,反向传播,优化
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# 打印统计信息
running_loss += loss.item()
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
acc = 100 * correct / total
acc_list.append(acc)
loss_list.append(running_loss / len(trainloader))
print('[%d] loss: %.3f, acc: %.3f' %
(epoch + 1, running_loss / len(trainloader), acc))
print('Finished Training')
# 绘制loss变化曲线和准确率变化曲线
import matplotlib.pyplot as plt
fig, axs = plt.subplots(2, 1, figsize=(10, 10))
axs[0].plot(loss_list)
axs[0].set_title("Training Loss")
axs[0].set_xlabel("Epoch")
axs[0].set_ylabel("Loss")
axs[1].plot(acc_list)
axs[1].set_title("Training Accuracy")
axs[1].set_xlabel("Epoch")
axs[1].set_ylabel("Accuracy")
plt.show()
# 测试模型
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (
100 * correct / total))
```
在这个例子中,我们将图像数据reshape成一维向量,然后通过一个RNN模型进行分类。模型由两层RNN和一个全连接层组成。在训练过程中,我们记录每个epoch的loss和准确率,并将它们绘制成曲线。最后,在测试集上评估模型的准确率。
可以使用Matplotlib库绘制loss和准确率的变化曲线,如上所示。
阅读全文