如何使用`timm`库来训练一个图像分类模型?
时间: 2024-09-07 17:04:06 浏览: 105
`timm`(PyTorch Image Models)是一个用于计算机视觉领域的深度学习库,它提供了大量预训练模型以及快速训练功能。使用`timm`库来训练一个图像分类模型,可以遵循以下步骤:
1. 安装`timm`库:
```bash
pip install timm
```
2. 导入必要的库和模块:
```python
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from timm import create_model
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
```
3. 准备数据集并进行数据预处理:
```python
transform = transforms.Compose([
transforms.Resize((224, 224)), # 调整图像大小以匹配模型输入尺寸
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 假设你的数据集分为训练集和验证集,并存储在相应的目录中
train_dataset = ImageFolder(root='path_to_train_dataset', transform=transform)
val_dataset = ImageFolder(root='path_to_val_dataset', transform=transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False)
```
4. 创建模型实例并定义训练参数:
```python
model = create_model('resnet18', pretrained=False, num_classes=num_classes) # num_classes为类别数量
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
```
5. 训练模型:
```python
epochs = 10
for epoch in range(epochs):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
if batch_idx % 10 == 0:
print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item()}")
# 验证模型
model.eval()
correct = 0
total = 0
with torch.no_grad():
for data, target in val_loader:
outputs = model(data)
_, predicted = torch.max(outputs.data, 1)
total += target.size(0)
correct += (predicted == target).sum().item()
print(f"Validation Accuracy: {100 * correct / total}%")
```
6. 保存模型:
```python
torch.save(model.state_dict(), 'image_classification_model.pth')
```
阅读全文