pytorch使用已经训练好的库MobileNet v2识别图片,代码
时间: 2023-09-29 14:09:30 浏览: 171
下面是使用已经训练好的库MobileNet v2识别图片的代码:
```python
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
# 加载 MobileNet v2 模型
model = models.mobilenet_v2(pretrained=True)
# 定义图片变换方式
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 加载图片
img = Image.open("test.jpg")
# 对图片进行变换
img_t = transform(img)
batch_t = torch.unsqueeze(img_t, 0)
# 计算模型输出
model.eval()
out = model(batch_t)
# 获取输出中的最大值及其索引
_, index = torch.max(out, 1)
# 加载预测标签文件
with open("imagenet_classes.txt") as f:
classes = [line.strip() for line in f.readlines()]
# 打印预测结果
print(classes[index])
```
其中,`models.mobilenet_v2(pretrained=True)`加载了已经经过 ImageNet 数据集训练好的 MobileNet v2 模型。变换方式`transform`将图片变换为模型输入所需要的格式。`torch.unsqueeze(img_t, 0)`将图片添加了一个 batch 维度,使其变为 4 维张量,可以输入到模型中进行计算。`torch.max(out, 1)`获取输出中的最大值及其索引,即为预测结果。最后,将预测结果与预测标签文件中的类别对应起来,即可输出预测结果。
阅读全文