Implementing an image recognition interface with Jupyter, using a custom model saved as a .pth file
Date: 2023-06-17 15:02:08
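To build an image recognition interface, we can use Python's Flask framework to create a web application that calls a PyTorch model to classify uploaded images.
First, install the required libraries: Flask, PyTorch (with torchvision), and Pillow.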
```bash
pip install flask torch torchvision pillow
```
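Next, we need a PyTorch model file in .pth format. Here we use ResNet50 as an example: we train it on the CIFAR10 dataset to obtain a 10-class classifier and save its weights (`state_dict`).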
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Define the model: a ResNet50 backbone (1000-dim output) followed by a 10-way linear head
class ResNet50(nn.Module):
    def __init__(self, num_classes=10):
        super(ResNet50, self).__init__()
        # weights=None trains from scratch (on torchvision < 0.13 use pretrained=False)
        self.resnet50 = torchvision.models.resnet50(weights=None)
        self.fc = nn.Linear(1000, num_classes)

    def forward(self, x):
        x = self.resnet50(x)
        x = self.fc(x)
        return x

# Load the dataset (CIFAR10 images are 32x32 RGB)
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

# Train the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = ResNet50(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# Multiply the learning rate by 0.1 at epochs 100 and 150
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)

for epoch in range(200):
    net.train()
    train_loss = 0.0
    train_correct = 0
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Accumulate the summed loss and the number of correct predictions
        train_loss += loss.item() * inputs.size(0)
        preds = outputs.argmax(dim=1)
        train_correct += (preds == labels).sum().item()
    # Step once per epoch so the milestones are counted in epochs
    scheduler.step()
    train_loss = train_loss / len(trainset)
    train_acc = train_correct / len(trainset)
    print('Epoch: {} | Loss: {:.4f} | Acc: {:.4f}'.format(epoch + 1, train_loss, train_acc))

# Save only the learned parameters (state_dict)
torch.save(net.state_dict(), 'resnet50_cifar10.pth')
```
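The scheduler above keeps the learning rate at 0.1 and multiplies it by `gamma=0.1` each time the epoch count passes a milestone. As a rough illustration, the plain-Python sketch below computes the learning rate in effect during each (0-based) epoch, assuming `scheduler.step()` is called once per epoch as in the loop above. The helper name `multistep_lr` is hypothetical; it re-implements the closed-form rule `base_lr * gamma ** k` (with `k` the number of milestones already reached) and is not PyTorch's own code.

```python
from bisect import bisect_right


def multistep_lr(base_lr, milestones, gamma, epoch):
    """Learning rate used during `epoch` (0-based) under a MultiStepLR-style rule.

    Hypothetical helper: k counts the milestones that are <= epoch,
    and the rate is base_lr * gamma ** k.
    """
    k = bisect_right(sorted(milestones), epoch)
    return base_lr * gamma ** k


if __name__ == '__main__':
    # Print the rate at the epochs around each milestone of the training setup above
    for epoch in (0, 99, 100, 149, 150, 199):
        print(epoch, multistep_lr(0.1, [100, 150], 0.1, epoch))
```

With the settings from the training script, epochs 0 to 99 use 0.1, epochs 100 to 149 use 0.01, and epochs 150 to 199 use 0.001 (up to floating-point rounding).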
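Next, we write a Flask application that handles image upload and model prediction. The app uses Flask's `request` object to get the uploaded image file, opens it with PIL, converts it into a PyTorch tensor with torchvision transforms, runs the model on it, and returns the prediction as JSON.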
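```python
from flask import Flask, request, jsonify
from PIL import Image
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

app = Flask(__name__)

# The model class must match the one used for training so the state_dict keys line up
class ResNet50(nn.Module):
    def __init__(self, num_classes=10):
        super(ResNet50, self).__init__()
        self.resnet50 = torchvision.models.resnet50(weights=None)
        self.fc = nn.Linear(1000, num_classes)

    def forward(self, x):
        return self.fc(self.resnet50(x))

# CIFAR10 class names, in torchvision's label order
CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

# Load the model; map_location lets GPU-trained weights load on a CPU-only machine
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet50(num_classes=10).to(device)
model.load_state_dict(torch.load('resnet50_cifar10.pth', map_location=device))
model.eval()

# Image preprocessing: resize to the 32x32 CIFAR10 input size and normalize as in training
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Prediction function: returns the index of the most likely class
def predict(image):
    image_tensor = transform(image).unsqueeze(0).to(device)  # add a batch dimension
    with torch.no_grad():
        outputs = model(image_tensor)
    _, preds = torch.max(outputs, 1)
    return preds.item()

# Flask route
@app.route('/', methods=['GET', 'POST'])
def index():
    if request.method == 'POST':
        # Get the uploaded file
        file = request.files.get('file')
        if file is None or file.filename == '':
            return jsonify({'error': 'no file uploaded'}), 400
        # Open the upload as a 3-channel PIL image (handles PNG alpha and grayscale)
        image = Image.open(file.stream).convert('RGB')
        # Run the model
        result = predict(image)
        # Return the class index and its name
        return jsonify({'result': result, 'label': CLASSES[result]})
    # GET: show the upload form
    return '''
<!doctype html>
<html>
<body>
<h1>Image Recognition</h1>
<form method="post" enctype="multipart/form-data">
<input type="file" name="file">
<input type="submit" value="Upload">
</form>
</body>
</html>
'''

if __name__ == '__main__':
    app.run()
```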
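Both the training code and the Flask app apply `ToTensor()` followed by `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`. For an 8-bit image, `ToTensor()` scales each value from 0..255 to [0, 1] and reorders the layout from height x width x channel to channel x height x width; `Normalize` then computes `(x - mean) / std` per channel, so with mean and std of 0.5 every value lands in [-1, 1]. The sketch below reproduces that arithmetic in plain Python on nested lists instead of tensors (the helper `to_normalized_chw` is hypothetical), which can help when checking that the serving pipeline matches training.

```python
def to_normalized_chw(pixels, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Mimic ToTensor() + Normalize(mean, std) on an 8-bit RGB image.

    `pixels` is a list of rows, each row a list of (R, G, B) tuples with
    values 0..255 (height x width x channel). Returns nested lists in
    channel x height x width order.
    """
    channels = []
    for c in range(3):
        plane = []
        for row in pixels:
            # ToTensor: divide by 255; Normalize: subtract mean, divide by std
            plane.append([(px[c] / 255.0 - mean[c]) / std[c] for px in row])
        channels.append(plane)
    return channels


if __name__ == '__main__':
    # A 1 x 2 image: one black pixel and one white pixel
    image = [[(0, 0, 0), (255, 255, 255)]]
    out = to_normalized_chw(image)
    print(len(out), len(out[0]), len(out[0][0]))  # → 3 1 2 (channels, height, width)
    print(out[0][0])  # → [-1.0, 1.0]
```

Black maps to -1.0 and white to 1.0 in every channel, which is the input range the model saw during training.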
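Run the Flask application from a terminal (the code above is assumed to be saved as `app.py`, in the same directory as `resnet50_cifar10.pth`):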
```bash
export FLASK_APP=app.py
flask run
```
Open http://127.0.0.1:5000 in a browser and upload an image; the app replies with a JSON object containing the predicted class index.
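Besides the browser form, the endpoint can be called from a script. The sketch below uses only the Python standard library: it builds a `multipart/form-data` body with a single file field named `file` (matching the `name="file"` input in the upload form) and posts it with `urllib.request`. The helper names `build_multipart` and `classify`, the fixed boundary string, and the default `image/jpeg` content type are assumptions for illustration; a third-party library such as `requests` would do the same in fewer lines.

```python
import json
import os
import urllib.request


def build_multipart(field, filename, data, content_type='image/jpeg',
                    boundary='----imageclassifierboundary'):
    """Return (body_bytes, content_type_header) for a one-file multipart/form-data upload."""
    head = (
        f'--{boundary}\r\n'
        f'Content-Disposition: form-data; name="{field}"; filename="{filename}"\r\n'
        f'Content-Type: {content_type}\r\n'
        '\r\n'
    ).encode('utf-8')
    # Closing delimiter marks the end of the multipart body
    tail = f'\r\n--{boundary}--\r\n'.encode('utf-8')
    return head + data + tail, f'multipart/form-data; boundary={boundary}'


def classify(path, url='http://127.0.0.1:5000/'):
    """POST the image at `path` to the Flask app and return the decoded JSON reply.

    urlopen raises urllib.error.HTTPError for non-2xx replies (e.g. a 400).
    """
    with open(path, 'rb') as f:
        body, header = build_multipart('file', os.path.basename(path), f.read())
    req = urllib.request.Request(url, data=body,
                                 headers={'Content-Type': header}, method='POST')
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read().decode('utf-8'))
```

For example, with the server running, `classify('test.jpg')` (a hypothetical file name) returns the app's JSON reply as a Python dict.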