Using transformers for CIFAR-10 image classification in PyTorch: saving the loss curve, accuracy curve, and model
Although the `transformers` library is best known for natural language processing, it also ships vision models such as ViT (Vision Transformer) that can be fine-tuned for image classification. The code below fine-tunes a pretrained ViT on CIFAR-10, saves the loss and accuracy curves, and saves the trained model.
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from transformers import ViTForImageClassification
# Hyperparameters
input_size = 224  # google/vit-base-patch16-224 was pretrained on 224x224 inputs
num_classes = 10
batch_size = 100
num_epochs = 5
learning_rate = 0.001
# Load the CIFAR-10 dataset and apply preprocessing
transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                             download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                           shuffle=True)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                            download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                          shuffle=False)
# Load a pretrained ViT and attach a new 10-class classification head
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224',
    num_labels=num_classes,
    ignore_mismatched_sizes=True).to(device)
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Track loss and accuracy during training
train_loss_list = []
test_loss_list = []
train_acc_list = []
test_acc_list = []
# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    train_total = 0
    train_correct = 0
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        # Forward pass (the model returns an ImageClassifierOutput; take .logits)
        outputs = model(images).logits
        loss = criterion(outputs, labels)
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate loss and accuracy
        train_loss += loss.item() * labels.size(0)
        _, predicted = torch.max(outputs, 1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()
        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))
    # Evaluate loss and accuracy on the test set
    model.eval()
    test_loss = 0
    test_total = 0
    test_correct = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images).logits
            loss = criterion(outputs, labels)
            test_loss += loss.item() * labels.size(0)
            _, predicted = torch.max(outputs, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()
    # Average loss and accuracy for this epoch
    train_loss = train_loss / train_total
    test_loss = test_loss / test_total
    train_acc = 100 * train_correct / train_total
    test_acc = 100 * test_correct / test_total
    train_loss_list.append(train_loss)
    test_loss_list.append(test_loss)
    train_acc_list.append(train_acc)
    test_acc_list.append(test_acc)
    print('Epoch [{}/{}], Train Loss: {:.4f}, Train Accuracy: {:.2f}%, Test Loss: {:.4f}, Test Accuracy: {:.2f}%'
          .format(epoch + 1, num_epochs, train_loss, train_acc, test_loss, test_acc))
# Save the loss and accuracy curves
plt.plot(train_loss_list, label='Train Loss')
plt.plot(test_loss_list, label='Test Loss')
plt.xlabel('Epoch')
plt.legend()
plt.savefig('loss.png')
plt.clf()
plt.plot(train_acc_list, label='Train Accuracy')
plt.plot(test_acc_list, label='Test Accuracy')
plt.xlabel('Epoch')
plt.legend()
plt.savefig('accuracy.png')
# Save the model
torch.save(model.state_dict(), 'model.pth')
```
Here a pretrained ViT serves as the feature extractor (backbone) and a fully connected (linear) head performs the 10-way classification. During training, the loss and accuracy on both the training and test sets are recorded each epoch, the loss and accuracy curves are saved as `loss.png` and `accuracy.png`, and the trained model's weights are saved to `model.pth`.
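If you only want to use the pretrained ViT as a frozen feature extractor and train just the linear classification head, you can freeze the backbone before building the optimizer. A minimal sketch, assuming the `ViTForImageClassification` model and `learning_rate` defined above (`.vit` is the backbone attribute inside `ViTForImageClassification`):
```python
# Freeze every parameter of the ViT backbone; only model.classifier stays trainable
for param in model.vit.parameters():
    param.requires_grad = False
# Build the optimizer over the remaining trainable parameters (the linear head)
optimizer = optim.Adam((p for p in model.parameters() if p.requires_grad),
                       lr=learning_rate)
```
With the backbone frozen, the backward pass only computes gradients for the classification head, which makes each training step considerably cheaper.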
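To use the saved weights later for inference, one possible approach (a sketch, assuming the same `num_classes`, `device`, and `test_dataset` definitions as in the script above):
```python
# Rebuild the same architecture, then load the trained weights
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224',
    num_labels=num_classes,
    ignore_mismatched_sizes=True)
model.load_state_dict(torch.load('model.pth', map_location=device))
model.to(device)
model.eval()

# Classify a single test image
classes = ('airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
image, label = test_dataset[0]
with torch.no_grad():
    logits = model(image.unsqueeze(0).to(device)).logits
    predicted = logits.argmax(dim=1).item()
print('Predicted: {}, Actual: {}'.format(classes[predicted], classes[label]))
```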