python vgg16分类模型测试代码
时间: 2024-12-29 07:25:50 浏览: 5
### Python VGG16 图像分类测试代码示例
为了展示如何使用预训练的VGG16模型进行图像分类,下面提供了一个完整的Python代码示例。此代码加载了预训练的VGG16模型,并对其进行了评估。
```python
import torch
from torchvision import models, transforms
from PIL import Image
import json
# 加载ImageNet类别标签
with open('imagenet_class_index.json') as f:
class_idx = json.load(f)
idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]
# 定义数据转换操作
preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# 加载预训练的VGG16模型并设置为评估模式
vgg16 = models.vgg16(pretrained=True)
vgg16.eval()
def predict(image_path):
"""预测单张图片"""
input_image = Image.open(image_path)
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)
# 使用GPU加速计算(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg16.to(device)
input_batch = input_batch.to(device)
with torch.no_grad():
output = vgg16(input_batch)
probabilities = torch.nn.functional.softmax(output[0], dim=0)
top5_prob, top5_catid = torch.topk(probabilities, 5)
results = []
for i in range(top5_prob.size(0)):
results.append((top5_prob[i].item(), idx2label[top5_catid[i]]))
return results
image_path = 'example.jpg'
predictions = predict(image_path)
for prob, label in predictions:
print(f"{label}: {prob:.2%}")
```
这段代码首先定义了一系列用于处理输入图像的数据变换方法[^1]。接着初始化了一个预训练好的VGG16模型实例,并将其置于评估模式下以禁用dropout等训练特有的行为。`predict()`函数接收一张图片路径作为参数,返回该图片最有可能属于的五个类别的概率及其名称列表。最后部分展示了如何调用这个函数并对结果进行打印输出。
阅读全文