CIFAR-10 image classification with transformers in PyTorch: loss and accuracy curves, and saving the model
Date: 2024-02-11 18:05:48
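When using the transformers library for CIFAR-10 image classification, you can use PyTorch's `nn.CrossEntropyLoss` as the loss function and Adam or SGD as the optimizer. The following simple example trains and saves the model, and records the loss and accuracy during training.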
```python
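import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms
from transformers import ViTForImageClassification

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define the model: a pretrained ViT with a new 10-class classification head.
# (ViTModel is only the encoder and has no head, so assigning model.fc has no effect.)
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224',
    num_labels=10,
    ignore_mismatched_sizes=True,  # replace the 1000-class ImageNet head
).to(device)

# Load the data: resize the 32x32 CIFAR images to the 224x224 input this checkpoint
# expects, and normalize with mean/std 0.5 as its image processor does
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
train_data = CIFAR10(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)

# Define the loss function and optimizer (a small learning rate suits fine-tuning)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Train the model, keeping the per-epoch loss and accuracy for plotting curves
num_epochs = 10
train_losses, train_accs = [], []
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(pixel_values=inputs).logits  # shape: (batch, 10)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Accumulate loss and accuracy statistics
        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    epoch_loss = running_loss / len(train_loader)
    epoch_acc = 100 * correct / total
    train_losses.append(epoch_loss)
    train_accs.append(epoch_acc)
    # Print the training results for this epoch
    print('Epoch [%d/%d], Loss: %.4f, Accuracy: %.2f%%'
          % (epoch + 1, num_epochs, epoch_loss, epoch_acc))

# Save the model weights
torch.save(model.state_dict(), 'cifar10_vit.pth')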
```
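To see why the classification head matters, the sketch below runs a tiny, randomly initialized ViT offline. The config sizes are arbitrary assumptions chosen only so it runs quickly (they are not the `vit-base` settings), and the images and labels are random placeholders. It compares `ViTModel`, which returns per-token hidden states, with `ViTForImageClassification`, which returns one logit per class, and checks that the loss the model computes when given `labels` matches `nn.CrossEntropyLoss` applied to those logits.

```python
import torch
import torch.nn as nn
from transformers import ViTConfig, ViTForImageClassification, ViTModel

torch.manual_seed(0)

# A tiny, randomly initialized ViT (hypothetical sizes chosen only for speed;
# no pretrained weights are downloaded)
config = ViTConfig(
    image_size=32, patch_size=4, num_channels=3,
    hidden_size=32, num_hidden_layers=2, num_attention_heads=4,
    intermediate_size=64, num_labels=10,
)
x = torch.randn(4, 3, 32, 32)   # placeholder batch of 4 CIFAR-sized images
y = torch.tensor([0, 3, 7, 9])  # placeholder integer class labels

# ViTModel is only the encoder: it returns hidden states, not class scores
backbone = ViTModel(config).eval()
with torch.no_grad():
    hidden = backbone(pixel_values=x).last_hidden_state
print(hidden.shape)  # (4, 65, 32): 64 patches + 1 [CLS] token, hidden size 32

# ViTForImageClassification adds a linear head on the [CLS] token and returns logits
clf = ViTForImageClassification(config).eval()
with torch.no_grad():
    out = clf(pixel_values=x, labels=y)
print(out.logits.shape)  # (4, 10): one score per CIFAR-10 class

# When labels are passed, the model computes cross-entropy itself;
# it matches nn.CrossEntropyLoss on the same logits
manual = nn.CrossEntropyLoss()(out.logits, y)
print(torch.allclose(out.loss, manual))
```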
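During training, you can log the loss and accuracy with a tool such as TensorBoard and visualize them as curves. Add the following to the training code, then run `tensorboard --logdir ./logs` to view the curves: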
```python
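from torch.utils.tensorboard import SummaryWriter

# Create the TensorBoard writer once, before the training loop
writer = SummaryWriter(log_dir='./logs')

# At the end of each epoch (inside the epoch loop, after the statistics are computed),
# log that epoch's average loss and accuracy
writer.add_scalar('Loss/train', running_loss / len(train_loader), epoch)
writer.add_scalar('Accuracy/train', 100 * correct / total, epoch)

# Close the TensorBoard writer after training finishes
writer.close()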
```
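If you prefer a static image file over TensorBoard, a minimal matplotlib sketch like the one below can draw both curves from per-epoch values. It assumes you append `running_loss / len(train_loader)` and `100 * correct / total` to two lists at the end of each epoch; the function name `save_curves` and the default output filename are illustrative choices, not part of any library.

```python
import os

import matplotlib
matplotlib.use('Agg')  # render to files; no display needed
import matplotlib.pyplot as plt


def save_curves(losses, accs, path='training_curves.png'):
    """Plot per-epoch loss and accuracy side by side and save them as one image."""
    os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
    epochs = range(1, len(losses) + 1)
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
    # Left panel: training loss per epoch
    ax_loss.plot(epochs, losses, marker='o')
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss')
    ax_loss.set_title('Training loss')
    # Right panel: training accuracy (%) per epoch
    ax_acc.plot(epochs, accs, marker='o', color='tab:orange')
    ax_acc.set_xlabel('Epoch')
    ax_acc.set_ylabel('Accuracy (%)')
    ax_acc.set_title('Training accuracy')
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)  # free the figure's memory
    return path
```

After training, call `save_curves(train_losses, train_accs)` with the lists you collected. Keeping plotting in a separate function means it can also be reused later on a history loaded from disk.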
Finally, you can load the saved model and evaluate it on the test set with the following code:
```python
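from torchvision import transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the saved weights into the ViTForImageClassification model built for training
model.load_state_dict(torch.load('cifar10_vit.pth', map_location=device))
model.to(device)
model.eval()  # evaluation mode: disables dropout
# Load the test data with the same preprocessing used for training
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
test_data = CIFAR10(root='./data', train=False, transform=transform, download=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)
# Evaluate the model
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(pixel_values=inputs).logits
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print('Test Accuracy: %.2f%%' % (100 * correct / total))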
```
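Saving only `model.state_dict()` is enough for testing. If you also want to resume training or redraw the curves later, one option is to bundle the optimizer state and the per-epoch history into a single checkpoint file. The helpers below are a sketch: `save_checkpoint` and `load_checkpoint` are hypothetical names, the dictionary keys are an arbitrary convention, and the demo uses a small `nn.Linear` stand-in with placeholder history values. In practice you would pass the ViT model and its optimizer.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim


def save_checkpoint(path, model, optimizer, epoch, losses, accs):
    # Bundle weights, optimizer state and the curve history into one file
    torch.save({
        'epoch': epoch,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'train_losses': losses,
        'train_accs': accs,
    }, path)


def load_checkpoint(path, model, optimizer=None, device='cpu'):
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt['model_state'])
    if optimizer is not None:
        optimizer.load_state_dict(ckpt['optimizer_state'])
    return ckpt['epoch'], ckpt['train_losses'], ckpt['train_accs']


# Demo with a tiny stand-in model and placeholder history values
net = nn.Linear(4, 10)
opt = optim.Adam(net.parameters(), lr=1e-4)
path = os.path.join(tempfile.mkdtemp(), 'ckpt.pth')
save_checkpoint(path, net, opt, epoch=3, losses=[2.1, 1.7, 1.4], accs=[30.0, 45.0, 52.0])

# Restore into a freshly constructed model and optimizer
restored = nn.Linear(4, 10)
restored_opt = optim.Adam(restored.parameters(), lr=1e-4)
epoch, losses, accs = load_checkpoint(path, restored, restored_opt)
```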
I hope this code helps!