Transformers做cifar-10图像分类任务pytorch版本效果好且有画图的代码
时间: 2024-01-17 11:04:37 浏览: 133
可以使用PyTorch实现一个基于Transformer的图像分类器来处理CIFAR-10数据集。以下是代码示例:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# 图像预处理
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# 加载数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)
# 构建模型
class TransformerModel(nn.Module):
def __init__(self):
super(TransformerModel, self).__init__()
self.transformer_encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6)
self.fc = nn.Linear(256, 10)
def forward(self, x):
# 将输入数据reshape成(seq_len, batch_size, feature_dim)
x = x.permute(0, 2, 3, 1).reshape(-1, x.shape[1] * x.shape[2], x.shape[3])
x = self.transformer_encoder(x)
x = x.mean(dim=1)
x = self.fc(x)
return x
net = TransformerModel()
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)
# 训练模型
for epoch in range(10): # 多批次循环
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
# 获取输入
inputs, labels = data
# 梯度清零
optimizer.zero_grad()
# 前向传播,反向传播,优化
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# 打印统计信息
running_loss += loss.item()
if i % 100 == 99: # 每100批次打印一次
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 100))
running_loss = 0.0
print('Finished Training')
# 测试模型
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (
100 * correct / total))
```
对于图像分类任务,我们需要将Transformer模型应用于图像的特征提取。这里我们使用了一个6层的TransformerEncoder,每一层由8个头的自注意力机制和前馈神经网络组成。最后,通过一个全连接层将提取的特征映射到10个类别的概率值。在训练过程中,使用交叉熵损失函数和Adam优化器。最后,我们在测试集上评估模型的准确率。
如果您想要可视化模型在CIFAR-10数据集上的分类结果,可以使用以下代码来生成一些样本图像和它们对应的预测结果:
```python
import matplotlib.pyplot as plt
# 获取一些测试图像
dataiter = iter(testloader)
images, labels = dataiter.next()
# 预测图像类别
outputs = net(images)
_, predicted = torch.max(outputs, 1)
# 输出图像和预测结果
fig = plt.figure(figsize=(25, 4))
for idx in range(20):
ax = fig.add_subplot(2, 20/2, idx+1, xticks=[], yticks=[])
ax.imshow(images[idx].permute(1, 2, 0))
ax.set_title("{} ({})".format(str(predicted[idx].item()), str(labels[idx].item())), color=("green" if predicted[idx]==labels[idx] else "red"))
```
这里我们展示了20张图像和它们对应的预测结果。预测正确的标签显示为绿色,预测错误的标签显示为红色。
阅读全文