基于pytorch的水果图像识别与分类系统的设计与实现,数据集使用Fruits 360,要求编写转换函数对数据集进行数据增强,模型要实现标准量化和批量归一化,并且实现权重衰减,梯度裁剪和Adam优化,最后将训练好的模型保存下来,并利用该模型实现一个有前后端的分类系统
时间: 2023-06-17 16:04:11 浏览: 422
好的,这是一个比较复杂的项目,需要一步一步实现。我们按照以下步骤进行:
1. 下载Fruits 360数据集,并进行数据增强
首先我们需要下载Fruits 360数据集,可以从官网(https://www.kaggle.com/moltean/fruits)或者Github(https://github.com/Horea94/Fruit-Images-Dataset)上下载。下载完成后,我们需要对数据集进行数据增强,以提高模型的鲁棒性和泛化能力。数据增强可以使用torchvision中的transforms模块来实现。
```python
import torchvision.transforms as transforms
# 定义数据增强操作
train_transforms = transforms.Compose([
transforms.RandomResizedCrop(224), # 随机裁剪出大小为224的图像
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.RandomRotation(30), # 随机旋转(-30, 30)度
transforms.ToTensor(), # 转换为张量
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 标准化
])
test_transforms = transforms.Compose([
transforms.Resize(256), # 调整到256大小
transforms.CenterCrop(224), # 中心裁剪出大小为224的图像
transforms.ToTensor(), # 转换为张量
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 标准化
])
```
2. 加载数据集
我们需要使用torch.utils.data中的DataLoader来加载数据集,以便于训练模型。
```python
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
# 定义数据集类
class Fruits360Dataset(Dataset):
def __init__(self, root_dir, transform=None):
self.dataset = ImageFolder(root_dir, transform=transform)
self.classes = self.dataset.classes
self.class_to_idx = self.dataset.class_to_idx
def __getitem__(self, index):
return self.dataset[index]
def __len__(self):
return len(self.dataset)
# 加载训练集和测试集
train_dataset = Fruits360Dataset("fruits-360/Training", transform=train_transforms)
test_dataset = Fruits360Dataset("fruits-360/Test", transform=test_transforms)
# 定义DataLoader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
```
3. 构建模型
我们使用ResNet50作为我们的模型,同时使用标准量化和批量归一化来提高模型的训练效果。
```python
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
# 定义模型
class FruitClassifier(nn.Module):
def __init__(self, num_classes):
super(FruitClassifier, self).__init__()
self.backbone = models.resnet50(pretrained=True)
self.backbone.fc = nn.Linear(2048, num_classes)
def forward(self, x):
x = self.backbone(x)
return x
# 实例化模型
model = FruitClassifier(num_classes=len(train_dataset.classes))
```
4. 定义损失函数、优化器和学习率调度器
我们使用交叉熵损失函数作为我们的损失函数,Adam优化器作为我们的优化器,并使用学习率调度器来动态调整学习率。
```python
import torch.optim as optim
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01)
# 定义学习率调度器
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
```
5. 训练模型
我们使用权重衰减和梯度裁剪来防止模型过拟合,并使用训练集和测试集来训练和评估模型。
```python
# 定义训练函数
def train(model, train_loader, criterion, optimizer, epoch, device):
model.train()
train_loss = 0
correct = 0
total = 0
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
optimizer.step()
train_loss += loss.item()
_, predicted = output.max(1)
total += target.size(0)
correct += predicted.eq(target).sum().item()
train_loss /= len(train_loader)
acc = 100. * correct / total
print('Epoch: {} Train Loss: {:.3f} Train Acc: {:.3f}'.format(epoch, train_loss, acc))
return train_loss, acc
# 定义测试函数
def test(model, test_loader, criterion, device):
model.eval()
test_loss = 0
correct = 0
total = 0
with torch.no_grad():
for batch_idx, (data, target) in enumerate(test_loader):
data, target = data.to(device), target.to(device)
output = model(data)
loss = criterion(output, target)
test_loss += loss.item()
_, predicted = output.max(1)
total += target.size(0)
correct += predicted.eq(target).sum().item()
test_loss /= len(test_loader)
acc = 100. * correct / total
print('Test Loss: {:.3f} Test Acc: {:.3f}'.format(test_loss, acc))
return test_loss, acc
# 训练模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
best_acc = 0
for epoch in range(1, 21):
train_loss, train_acc = train(model, train_loader, criterion, optimizer, epoch, device)
test_loss, test_acc = test(model, test_loader, criterion, device)
scheduler.step(test_loss)
# 保存最好的模型
if test_acc > best_acc:
best_acc = test_acc
torch.save(model.state_dict(), "fruit_classifier.pth")
```
6. 前后端分类系统
我们使用Flask作为我们的后端框架,使用HTML和JavaScript作为我们的前端页面。我们将训练好的模型加载到后端,并使用POST请求将前端上传的图片发送到后端进行预测。
```python
from flask import Flask, request, jsonify
from PIL import Image
import io
import base64
# 加载模型
model = FruitClassifier(num_classes=len(train_dataset.classes))
model.load_state_dict(torch.load('fruit_classifier.pth', map_location=device))
model.eval()
app = Flask(__name__)
@app.route('/')
def index():
return '''
<!doctype html>
<html>
<body>
<h2>Upload a fruit image</h2>
<form id="my-form">
<input type="file" id="my-file" name="my-file">
<button type="submit">Submit</button>
</form>
<div id="result"></div>
<script>
const form = document.querySelector("#my-form");
const resultDiv = document.querySelector("#result");
form.addEventListener("submit", function(event) {
event.preventDefault();
const fileInput = document.querySelector("#my-file");
const file = fileInput.files[0];
const reader = new FileReader();
reader.readAsDataURL(file);
reader.onload = function() {
const base64Data = reader.result.split(",")[1];
const url = "http://localhost:5000/predict";
const data = { image: base64Data };
fetch(url, {
method: "POST",
body: JSON.stringify(data),
headers: {
"Content-Type": "application/json"
}
})
.then(response => response.json())
.then(result => {
resultDiv.innerHTML = "<h2>Prediction: " + result.prediction + "</h2>";
});
};
});
</script>
</body>
</html>
'''
@app.route('/predict', methods=['POST'])
def predict():
data = request.get_json()
image_data = base64.b64decode(data['image'])
image = Image.open(io.BytesIO(image_data))
image = test_transforms(image).unsqueeze(0).to(device)
output = model(image)
prediction = train_dataset.classes[output.argmax().item()]
return jsonify({'prediction': prediction})
if __name__ == '__main__':
app.run()
```
运行以上代码,我们就可以在浏览器中访问http://localhost:5000/,上传一张水果图片进行分类了。
阅读全文