基于已训练好的CIFAR模型.pth文件将500张32*32像素的图片进行识别并将图片和标签组成键值对保存成torch.tensor格式的代码
时间: 2024-02-15 20:03:10 浏览: 74
好的,以下是基于已训练好的CIFAR模型.pth文件将500张32*32像素的图片进行识别并将图片和标签组成键值对保存成torch.tensor格式的代码:
```python
import torch
import torchvision.transforms as transforms
from PIL import Image
import os
# 加载已训练好的模型
model = torch.load('cifar_model.pth')
# 定义数据转换
transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 加载测试图片并进行预测
data_dir = 'test_images'
file_list = os.listdir(data_dir)
results = {}
for file_name in file_list:
file_path = os.path.join(data_dir, file_name)
image = Image.open(file_path).convert('RGB')
image_tensor = transform(image).unsqueeze(0)
output = model(image_tensor)
_, predicted = torch.max(output, 1)
results[file_name] = predicted.item()
# 将结果保存为torch.tensor格式
results_tensor = torch.tensor(list(results.items()))
# 打印结果
print(results_tensor)
```
这段代码中,我们首先加载了已训练好的模型,并定义了数据转换。然后,我们遍历了包含500张测试图片的文件夹,对每张图片进行预测,将图片名和预测标签组成键值对保存到results字典中。最后,我们将结果转化为torch.tensor格式,并打印输出。
阅读全文