vgg16预测单张图片分类可能性的代码pytorch
时间: 2023-05-31 07:05:10 浏览: 104
以下是使用PyTorch实现VGG16模型进行单张图片分类预测的代码:
```python
import torch
import torchvision.transforms as transforms
from PIL import Image
import json
# 加载类别名称
with open('imagenet_class_index.json', 'r') as f:
label_map = json.load(f)
# 定义预处理程序
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 加载模型
model = torch.hub.load('pytorch/vision', 'vgg16', pretrained=True)
# 设置模型为评估模式
model.eval()
# 加载图像
image = Image.open('test.jpg')
# 对图像进行预处理
image_tensor = transform(image)
# 添加一个维度,以匹配模型输入
image_tensor = image_tensor.unsqueeze(0)
# 使用模型进行预测
with torch.no_grad():
output = model(image_tensor)
# 获取预测结果
pred = output.argmax(dim=1)[0].item()
# 输出结果
print('预测结果:', label_map[str(pred)][1])
```
在代码中,我们首先加载了ImageNet数据集的标签映射,然后定义了一个预处理程序,用于对输入图像进行预处理。我们使用`torch.hub.load`函数加载了预训练的VGG16模型,并将其设置为评估模式。然后,我们加载了要进行预测的图像,并对其进行预处理。最后,我们使用模型对图像进行预测,并输出预测结果。
阅读全文