写一个基于pytorch水果图像识别系统的代码
时间: 2023-10-28 08:06:07 浏览: 154
好的,以下是基于 PyTorch 的水果图像识别系统代码:
```python
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
# 加载预训练模型
model = torch.load('fruit_classifier.pth', map_location=torch.device('cpu'))
# 定义标签
labels = ['apple', 'banana', 'orange']
# 定义图像预处理方法
preprocess = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# 定义图像分类函数
def classify(image_path):
# 加载图像
image = Image.open(image_path)
# 预处理图像
image_tensor = preprocess(image)
image_tensor = image_tensor.unsqueeze(0)
# 使用模型进行预测
output = model(image_tensor)
_, predicted = torch.max(output, 1)
# 输出分类结果
return labels[predicted[0]]
```
使用方法:
1. 准备好 PyTorch 训练好的模型文件 `fruit_classifier.pth`。
2. 将待分类的水果图像放入 `image_path` 变量指定的目录中。
3. 调用 `classify(image_path)` 方法,即可输出图片的水果分类结果。
注意:该代码仅适用于分类苹果、香蕉、橘子三种常见水果的图像,对于其它类型的图像可能无法正确分类。
阅读全文