测试一个图的分类正确率,Pytorch vgg16
时间: 2023-05-30 15:08:01 浏览: 128
可以使用以下代码测试一个图的分类正确率:
```python
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
# 加载预训练的VGG16模型
model = models.vgg16(pretrained=True)
model.eval()
# 图像预处理
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 加载测试图像
img = Image.open('test_image.jpg')
# 预处理图像数据
img_tensor = transform(img)
# 添加批次维度
img_tensor = img_tensor.unsqueeze(0)
# 使用模型进行预测
with torch.no_grad():
output = model(img_tensor)
# 获取预测结果
_, predicted = torch.max(output.data, 1)
# 输出分类结果
print('Predicted class: ', predicted.item())
```
这里假设测试图像为`test_image.jpg`,需要将其放在脚本所在的目录下。在运行脚本之前,需要确保已经安装了PyTorch和Pillow库。
该脚本加载了预训练的VGG16模型,并对测试图像进行了预处理。然后使用模型进行预测,并输出分类结果。
阅读全文