pytorh 训练过程 获取单张图片的预测结果
时间: 2023-05-30 22:05:36 浏览: 104
在 PyTorch 训练过程中,获取单张图片的预测结果需要进行以下步骤:
1. 加载模型:使用 `torch.load()` 函数加载训练好的模型。
2. 加载图片:使用 PIL 或 OpenCV 等库加载单张图片,并进行预处理,如调整大小、变换通道格式等。
3. 进行预测:将预处理后的图片输入到模型中,使用 `model.eval()` 函数将模型设置为评估模式,然后调用 `model()` 函数进行预测。预测结果通常是一个向量,可以使用 `torch.max()` 函数获取预测结果中概率最大的类别。
4. 展示结果:根据预测结果,可以展示图片和预测类别,以及预测概率等信息。
下面是一个示例代码:
```python
import torch
import torchvision.transforms as transforms
from PIL import Image
# 加载模型
model = torch.load('model.pth')
# 加载图片并进行预处理
img = Image.open('test.jpg')
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
img = transform(img)
img = img.unsqueeze(0)
# 进行预测
model.eval()
output = model(img)
_, predicted = torch.max(output, 1)
# 展示结果
print('Predicted class:', predicted.item())
```
阅读全文