使用 Python 来搭建后端接口,将训练好的模型加载进来,当用户上传一张水果照片时,调用接口,根据图像识别出水果种类,返回识别出的水果类型和信度值给前端的代码
时间: 2024-05-13 16:15:26 浏览: 70
好的,以下是使用 Flask 框架搭建后端接口的示例代码:
```python
from flask import Flask, request, jsonify
import torch
import torchvision.transforms as transforms
from PIL import Image
app = Flask(__name__)
# 加载训练好的模型
model = torch.load('fruit_classifier.pth', map_location=torch.device('cpu'))
model.eval()
# 定义分类标签
labels = ['apple', 'banana', 'orange']
# 定义图像预处理函数
def preprocess(image):
transform = transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
image = transform(image)
image = image.unsqueeze(0)
return image
# 定义接口
@app.route('/classify', methods=['POST'])
def classify():
# 获取上传的图片
image_file = request.files['image']
image = Image.open(image_file.stream)
# 图像分类
image = preprocess(image)
with torch.no_grad():
output = model(image)
probs = torch.softmax(output, dim=1)
confidence, predicted = torch.max(probs, dim=1)
label = labels[predicted.item()]
# 返回分类结果
result = {
'label': label,
'confidence': confidence.item()
}
return jsonify(result)
if __name__ == '__main__':
app.run()
```
这段代码使用了 PyTorch 框架来加载训练好的模型,使用了 Flask 框架来搭建接口,并使用了 Pillow 库来读取和处理图像。当用户上传一张水果照片时,调用 `/classify` 接口,对图像进行分类,并返回识别出的水果类型和信度值。
阅读全文