Running a `transformers` model on CPU for CIFAR-10 image classification, with loss and accuracy curves and model saving (PyTorch)
Posted: 2024-02-11 10:08:05
Below is example code that uses the `transformers` library and PyTorch for a CIFAR-10 image classification task, including training the model, plotting the training curves, and saving the model. Since the CIFAR-10 dataset is fairly small, the model can be trained on a CPU. The code is as follows:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from transformers import ViTForImageClassification, ViTConfig
import matplotlib.pyplot as plt

# Hyperparameters
batch_size = 100
learning_rate = 0.001
num_epochs = 10

# Load the CIFAR-10 dataset
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

# Initialize a small ViT classifier. Note: ViTModel has no classification
# head, so we use ViTForImageClassification, and the keyword names below
# (hidden_size, num_hidden_layers, ...) are the ones ViTConfig actually accepts.
config = ViTConfig(image_size=32, patch_size=4, num_labels=10,
                   hidden_size=64, num_hidden_layers=6,
                   num_attention_heads=8, intermediate_size=128)
vit = ViTForImageClassification(config)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vit.parameters(), lr=learning_rate)
# Train the model
train_loss_list = []
train_acc_list = []
test_loss_list = []
test_acc_list = []
for epoch in range(num_epochs):
    vit.train()
    train_loss = 0.0
    train_total = 0
    train_correct = 0
    for inputs, labels in trainloader:
        optimizer.zero_grad()
        outputs = vit(pixel_values=inputs).logits  # logits shape: (batch, 10)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()
    train_loss_list.append(train_loss / len(trainloader))
    train_acc_list.append(train_correct / train_total)
    print('Epoch %d, training loss: %.3f, training accuracy: %.3f' %
          (epoch + 1, train_loss / len(trainloader), train_correct / train_total))

    vit.eval()
    test_loss = 0.0
    test_total = 0
    test_correct = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            outputs = vit(pixel_values=inputs).logits
            loss = criterion(outputs, labels)
            test_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            test_total += labels.size(0)
            test_correct += (predicted == labels).sum().item()
    test_loss_list.append(test_loss / len(testloader))
    test_acc_list.append(test_correct / test_total)
    print('Epoch %d, testing loss: %.3f, testing accuracy: %.3f' %
          (epoch + 1, test_loss / len(testloader), test_correct / test_total))
# Plot the loss curves
plt.figure()
plt.plot(range(1, num_epochs + 1), train_loss_list, 'b-', label='Training Loss')
plt.plot(range(1, num_epochs + 1), test_loss_list, 'r-', label='Testing Loss')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()

# Plot the accuracy curves
plt.figure()
plt.plot(range(1, num_epochs + 1), train_acc_list, 'b-', label='Training Accuracy')
plt.plot(range(1, num_epochs + 1), test_acc_list, 'r-', label='Testing Accuracy')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.show()
# Save the model weights
torch.save(vit.state_dict(), 'vit_cifar10.pth')
```
After running the code above, you will get the loss and accuracy curves for both training and testing, and the trained model will be saved as `vit_cifar10.pth` in the current directory. You can load the saved model with:
```python
vit = ViTForImageClassification(config)
vit.load_state_dict(torch.load('vit_cifar10.pth'))
vit.eval()
```
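As a quick sanity check of the model setup, you can run a forward pass on a dummy batch and confirm the logits have the expected shape. This sketch rebuilds the same small configuration used in the training script (the specific hyperparameter values are just this example's choices) rather than loading weights from disk, so it runs standalone:

```python
import torch
from transformers import ViTForImageClassification, ViTConfig

# Same small configuration as in the training script above
config = ViTConfig(image_size=32, patch_size=4, num_labels=10,
                   hidden_size=64, num_hidden_layers=6,
                   num_attention_heads=8, intermediate_size=128)
vit = ViTForImageClassification(config)
vit.eval()  # disable dropout for inference

# A dummy batch of 4 CIFAR-10-sized images (3 channels, 32x32)
dummy = torch.randn(4, 3, 32, 32)
with torch.no_grad():
    logits = vit(pixel_values=dummy).logits
print(logits.shape)               # torch.Size([4, 10])
predicted = logits.argmax(dim=1)  # predicted class index per image
print(predicted.shape)            # torch.Size([4])
```

The same `pixel_values=...` call and `.logits` attribute apply when you feed real CIFAR-10 batches from the test loader.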