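CIFAR-10 image classification with transformers in PyTorch, with loss and accuracy curves and model saving
Posted: 2024-02-11 14:05:08
CIFAR-10 image classification with a Vision Transformer from the `transformers` library can be implemented in PyTorch. The following simple example trains the model, plots the loss and accuracy curves, and saves the model: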
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from transformers import ViTModel, ViTConfig
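# Define the model: a randomly initialized ViT backbone plus a linear classification head
class ViTClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        # num_labels is the documented config field for the number of classes
        config = ViTConfig(image_size=32, num_labels=10)
        self.vit = ViTModel(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, x):
        x = self.vit(x)
        # Classify from the embedding of the first ([CLS]) token
        x = self.classifier(x.last_hidden_state[:, 0])
        return x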
# Data preprocessing: augmentation for training, normalization only for testing
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
# Create the model first so that its parameters can be passed to the optimizer
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = ViTClassifier().to(device)
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# Start training
num_epochs = 50
loss_list = []
acc_list = []
for epoch in range(num_epochs):
    net.train()
    running_loss = 0.0
    total = 0
    correct = 0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Accumulate the batch loss and the number of correct predictions
        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    # Record the per-epoch training loss and accuracy for plotting
    loss_list.append(running_loss / len(trainloader))
    acc_list.append(correct / total)
    print('[Epoch %d] Loss: %.3f | Acc: %.3f' % (epoch + 1, running_loss / len(trainloader), correct / total))
# Plot the loss and accuracy curves
import matplotlib.pyplot as plt
plt.plot(loss_list)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()
plt.plot(acc_list)
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.show()
# Save the model weights (state_dict only)
torch.save(net.state_dict(), 'model.pth')
```
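The script above builds `testloader` but only reports accuracy on the training batches. A minimal sketch of a held-out evaluation pass might look like the following. `evaluate` is a hypothetical helper that is not part of the original code, and it assumes the model returns raw class logits, as `ViTClassifier` does.

```python
import torch
import torch.nn as nn


@torch.no_grad()  # no gradients are needed for evaluation
def evaluate(model: nn.Module, loader, criterion, device: torch.device):
    """Return (mean loss per batch, accuracy) over every batch in `loader`."""
    model.eval()  # switch off train-only behavior such as dropout
    total_loss, correct, total = 0.0, 0, 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        total_loss += criterion(outputs, labels).item()
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    return total_loss / len(loader), correct / total
```

Called as `evaluate(net, testloader, criterion, device)` at the end of each epoch, it would give a test loss and test accuracy that could be appended to their own lists and plotted next to the training curves. Because it leaves the model in eval mode, the training loop should call `net.train()` again before the next epoch.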
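During training, the loss and accuracy of each epoch are recorded in `loss_list` and `acc_list`. Both values are computed on the training set. Matplotlib then plots the two curves, one figure each. Finally, `torch.save()` writes the model's `state_dict` to `model.pth` on disk for later inference and deployment.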
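Since only the `state_dict` is saved, using the file later means rebuilding the same architecture and loading the weights into it. The sketch below shows one way this could be done. `load_for_inference` and `predict` are hypothetical helper names, and the code assumes the file was written with `torch.save(model.state_dict(), path)` as above.

```python
import torch
import torch.nn as nn


def load_for_inference(model: nn.Module, path: str, device: torch.device) -> nn.Module:
    """Load saved weights into an already constructed model and put it in eval mode."""
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine
    state_dict = torch.load(path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


@torch.no_grad()
def predict(model: nn.Module, images: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Return the predicted class index for each image in a batch."""
    logits = model(images.to(device))
    return logits.argmax(dim=1).cpu()
```

With the classes from the training script in scope, usage would be along the lines of `net = load_for_inference(ViTClassifier(), 'model.pth', device)` followed by `predict(net, batch, device)`, where `batch` has been preprocessed with `transform_test`.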