pytorch使用训练好的模型,读取文件夹中的图片进行预测,并保存结果
时间: 2024-02-17 18:02:46 浏览: 123
好的,下面是使用 PyTorch 加载已训练好的模型,读取文件夹中的图片进行预测,并保存结果的代码:
```python
import torch
import torchvision.transforms as transforms
from PIL import Image
import os
# 加载模型
model = torch.load('model.pth', map_location=torch.device('cpu'))
# 定义变换
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 预测图片
image_folder = 'image_folder'
result_folder = 'result_folder'
if not os.path.exists(result_folder):
os.mkdir(result_folder)
for filename in os.listdir(image_folder):
image_path = os.path.join(image_folder, filename)
image = Image.open(image_path)
image_tensor = transform(image)
image_tensor = image_tensor.unsqueeze(0)
with torch.no_grad():
output = model(image_tensor)
_, predicted = torch.max(output.data, 1)
predicted_class = predicted.item()
result_path = os.path.join(result_folder, filename)
with open(result_path, 'w') as f:
f.write(str(predicted_class))
```
其中,`model.pth` 是已训练好的模型文件,`image_folder` 是存放待预测图片的文件夹,`result_folder` 是存放预测结果的文件夹。
这段代码首先加载模型,然后定义了一系列变换,包括将图片调整为 256x256 的大小、裁剪中心的 224x224 区域、将图片转换为张量、以及对三个通道进行归一化。接着,代码遍历 `image_folder` 中的所有图片,对每一张图片进行预测,并将预测结果保存到 `result_folder` 中对应的文件中。最后,这段代码会生成一个名为 `result_folder` 的文件夹,其中包含与 `image_folder` 中图片数量相同的文件,每个文件中包含了对应图片的预测结果。
阅读全文