transformers做cifar10图像分类任务训练和测试最后画出曲线图并保存模型
时间: 2023-12-30 08:03:15 浏览: 124
好的,针对您的问题,我将为您提供一份基于 Transformers 实现 CIFAR-10 图像分类任务的 Python 代码。请注意,由于训练时间较长,我将仅提供代码框架,您可以按照注释中的提示进行修改和完善。
```python
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from transformers import ViTModel, ViTForImageClassification, ViTConfig, AdamW, get_linear_schedule_with_warmup
# 设置随机种子
torch.manual_seed(42)
# 定义超参数
BATCH_SIZE = 32
LEARNING_RATE = 2e-5
EPOCHS = 10
WARMUP_STEPS = 1000
# 加载 CIFAR-10 数据集
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
# 加载 ViT 模型
config = ViTConfig(image_size=224, num_classes=10, hidden_size=768, num_hidden_layers=12, num_attention_heads=12)
model = ViTForImageClassification(config)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
# 设置学习率调度器
total_steps = len(train_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=WARMUP_STEPS, num_training_steps=total_steps)
# 训练模型
for epoch in range(EPOCHS):
model.train()
for i, (images, labels) in enumerate(train_loader):
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs.logits, labels)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
scheduler.step()
if i % 100 == 0:
print(f'Epoch [{epoch+1}/{EPOCHS}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')
# 测试模型
model.eval()
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
outputs = model(images)
_, predicted = torch.max(outputs.logits.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f'Test Accuracy: {accuracy:.2f}%')
# 保存模型
torch.save(model.state_dict(), 'vit_cifar10.pth')
```
在上述代码中,我们首先使用 PyTorch 内置的 `datasets` 和 `transforms` 模块加载并预处理 CIFAR-10 数据集;然后使用 Transformers 提供的 `ViTForImageClassification` 类定义 ViT 模型,并使用 AdamW 优化器和学习率调度器进行训练;最后在测试集上评估模型,并使用 `torch.save()` 方法保存模型参数。
您可以根据自己的需要对代码进行修改和完善,例如调整超参数、使用其他预训练模型等。在训练完成后,您可以使用 Matplotlib 绘制训练和测试曲线图,例如:
```python
import matplotlib.pyplot as plt
# 绘制训练和测试曲线图
plt.plot(train_losses, label='Training loss')
plt.plot(test_losses, label='Testing loss')
plt.legend()
plt.show()
```
其中 `train_losses` 和 `test_losses` 分别为每个 epoch 的训练损失和测试损失。
阅读全文
相关推荐


















