怎么调用pytorch预训练的mobileNet模型实现图像10分类?
时间: 2023-12-03 21:45:06 浏览: 144
可以使用 PyTorch 提供的 `torchvision.models` 模块中的 `mobilenet_v2` 模型来实现图像 10 分类。以下是一个简单的例子:
```python
import torch
from torchvision import models
# 加载预训练的 MobileNetV2 模型
model = models.mobilenet_v2(pretrained=True)
# 将模型设置为评估模式
model.eval()
# 加载并预处理图像,这里假设图像路径为 image.jpg
from PIL import Image
from torchvision import transforms
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
image = Image.open("image.jpg")
image_tensor = transform(image).unsqueeze_(0)
# 使用模型进行预测
with torch.no_grad():
output = model(image_tensor)
_, predicted = torch.max(output.data, 1)
# 输出预测结果
print("Predicted class:", predicted.item())
```
在这个例子中,我们首先加载预训练的 MobileNetV2 模型,并将其设置为评估模式。然后,我们使用 `PIL` 库加载图像,并使用 `torchvision.transforms` 模块中的 `Compose` 函数定义图像预处理的操作,包括将图像缩放到 256x256 像素、居中裁剪到 224x224 像素、将图像转换为张量,并对张量进行归一化。接下来,我们将预处理的图像传递给模型,使用 `torch.no_grad()` 块禁用梯度计算,使用 `torch.max()` 函数获取预测结果,并将预测结果输出到控制台。
阅读全文