transformer模型图像分类任务
时间: 2023-11-20 14:58:11 浏览: 103
Transformer模型在图像分类任务中也有广泛的应用。其中,Swin Transformer是一种新型的Transformer模型,它在图像分类任务中表现出色。下面是使用Swin Transformer进行图像分类任务的示例代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from timm.models import create_model
from timm.loss import LabelSmoothingCrossEntropy
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler
# 定义超参数
batch_size = 64
num_epochs = 20
learning_rate = 0.001
num_classes = 12
# 定义数据增强和预处理
train_transforms = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
val_transforms = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# 加载数据集
train_dataset = datasets.ImageFolder('train', transform=train_transforms)
val_dataset = datasets.ImageFolder('val', transform=val_transforms)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
# 定义模型
model = create_model('swin_base_patch4_window12_384', pretrained=True, num_classes=num_classes)
# 定义损失函数和优化器
criterion = LabelSmoothingCrossEntropy()
optimizer = create_optimizer(learning_rate, model)
scheduler = create_scheduler(optimizer, num_epochs, len(train_loader))
# 训练模型
for epoch in range(num_epochs):
model.train()
for images, labels in train_loader:
images = images.cuda()
labels = labels.cuda()
outputs = model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step()
model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in val_loader:
images = images.cuda()
labels = labels.cuda()
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Epoch [{}/{}], Loss: {:.4f}, Accuracy: {:.2f}%'
.format(epoch+1, num_epochs, loss.item(), 100 * correct / total))
# 保存模型
torch.save(model.state_dict(), 'swin_transformer.pth')
```
该示例代码使用Swin Transformer模型对一个包含12个类别的图像分类数据集进行训练和验证,并将训练好的模型保存到本地。在训练过程中,使用了Label Smoothing Cross Entropy作为损失函数,并使用了学习率调度器来动态调整学习率。在验证过程中,计算了模型的准确率。
阅读全文