基于pytorch使用VGG模型实现对cifar数据集的分类
时间: 2024-04-12 14:31:59 浏览: 180
利用pytorch对CIFAR数据进行图像分类(包含全套代码和10+个模型的实现)
5星 · 资源好评率100%
下面是一个使用PyTorch中的VGG模型对CIFAR数据集进行分类的代码示例:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.models import vgg16
# 定义数据预处理的转换
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# 加载训练集和测试集
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
# 加载预训练的VGG模型
model = vgg16(pretrained=True)
# 修改最后一层全连接层的输出维度,使其适应CIFAR数据集的类别数(10个类别)
num_features = model.classifier[6].in_features
model.classifier[6] = nn.Linear(num_features, 10)
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# 训练模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
num_epochs = 10
for epoch in range(num_epochs):
running_loss = 0.0
for images, labels in train_loader:
images = images.to(device)
labels = labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"Epoch {epoch+1} - Training loss: {running_loss / len(train_loader)}")
# 在测试集上评估模型
model.eval()
correct = 0
total = 0
with torch.no_grad():
for images, labels in test_loader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
print(f"Accuracy on test set: {accuracy}%")
```
在上述代码中,首先定义了数据预处理的转换,将图像调整大小为224x224并进行归一化处理。然后使用`datasets.CIFAR10`类加载CIFAR-10数据集。接着创建数据加载器,指定数据集和批次大小。
然后使用`torchvision.models.vgg16`加载预训练的VGG模型,并修改最后一层全连接层的输出维度,以适应CIFAR数据集的类别数。定义损失函数和优化器。
接下来,将模型移动到GPU(如果可用),并进行训练。训练过程中,迭代训练数据并计算损失,并进行反向传播和参数更新。
训练完成后,将模型设置为评估模式,并在测试集上进行评估。计算模型在测试集上的准确率并打印结果。
请确保已经安装了PyTorch和torchvision库,并将数据集下载到指定的路径(在上述代码中为'./data')。
阅读全文