使用PyTorch中的VGG16模型对预训练的VGG-16网络进行图像分类,请给出完整python代码
时间: 2023-06-10 07:08:03 浏览: 161
好的,以下是使用PyTorch中的VGG16模型对预训练的VGG-16网络进行图像分类的完整Python代码:
```
import torch
import torchvision
import torchvision.transforms as transforms
# 加载预训练的VGG-16模型
vgg16 = torchvision.models.vgg16(pretrained=True)
# 设置要进行图像分类的数据集的路径
data_path = 'path/to/your/dataset'
# 定义数据预处理的操作
transform = transforms.Compose(
[transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
# 加载数据集
dataset = torchvision.datasets.ImageFolder(root=data_path, transform=transform)
# 定义数据加载器
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
# 定义分类标签
classes = dataset.classes
# 设定设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 将模型移动到设备上
vgg16.to(device)
# 开始测试
with torch.no_grad():
for i, data in enumerate(dataloader, 0):
# 获取图像和标签数据
images, labels = data[0].to(device), data[1].to(device)
# 运行VGG-16模型进行分类
outputs = vgg16(images)
# 获取预测结果
_, predicted = torch.max(outputs, 1)
# 显示预测结果
print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))
```
阅读全文