基于pytorch水果图像识别系统,数据集使用Fruits 360,并将训练好的模型应用到图片识别界面中,具体代码实现
时间: 2024-06-09 13:07:06 浏览: 181
好的,下面是基于PyTorch的水果图像识别系统的代码实现,包括数据集使用Fruits 360和将训练好的模型应用到图片识别界面中。
首先需要安装以下依赖库:
```python
pip install torch torchvision matplotlib numpy Pillow PyQt5
```
接下来,我们需要下载Fruits 360数据集。可以从Kaggle下载(https://www.kaggle.com/moltean/fruits),或者使用以下代码从GitHub上下载:
```python
!git clone https://github.com/Horea94/Fruit-Images-Dataset.git
```
下载完成后,我们需要对数据集进行预处理和划分。以下是完整的数据预处理和划分代码:
```python
import os
import shutil
import random
# 定义数据集路径
data_path = './Fruit-Images-Dataset'
# 定义训练集和测试集路径
train_path = './fruits-360/Training'
test_path = './fruits-360/Test'
# 定义训练集和测试集比例
train_ratio = 0.8
test_ratio = 0.2
# 获取所有类别
classes = os.listdir(data_path)
# 创建训练集和测试集目录
os.makedirs(train_path, exist_ok=True)
os.makedirs(test_path, exist_ok=True)
# 遍历每个类别
for cls in classes:
cls_path = os.path.join(data_path, cls)
imgs = os.listdir(cls_path)
num_imgs = len(imgs)
num_train = int(num_imgs * train_ratio)
num_test = num_imgs - num_train
# 创建训练集和测试集子目录
train_cls_path = os.path.join(train_path, cls)
test_cls_path = os.path.join(test_path, cls)
os.makedirs(train_cls_path, exist_ok=True)
os.makedirs(test_cls_path, exist_ok=True)
# 随机划分训练集和测试集
random.shuffle(imgs)
train_imgs = imgs[:num_train]
test_imgs = imgs[num_train:]
# 复制图片到训练集和测试集目录
for img in train_imgs:
src_path = os.path.join(cls_path, img)
dst_path = os.path.join(train_cls_path, img)
shutil.copy(src_path, dst_path)
for img in test_imgs:
src_path = os.path.join(cls_path, img)
dst_path = os.path.join(test_cls_path, img)
shutil.copy(src_path, dst_path)
```
接下来,我们需要定义模型。以下是使用ResNet-18作为模型的代码:
```python
import torch.nn as nn
import torchvision.models as models
class FruitClassifier(nn.Module):
def __init__(self, num_classes):
super(FruitClassifier, self).__init__()
self.model = models.resnet18(pretrained=True)
num_features = self.model.fc.in_features
self.model.fc = nn.Linear(num_features, num_classes)
def forward(self, x):
x = self.model(x)
return x
```
我们可以使用预训练的ResNet-18模型,并将其输出层替换为一个具有num_classes个输出的全连接层。在前向传递期间,我们只需调用self.model(x)。
接下来,我们需要定义训练和测试函数:
```python
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
def train(model, train_loader, optimizer, criterion):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
if batch_idx % 10 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
def test(model, test_loader, criterion):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
output = model(data)
test_loss += criterion(output, target).item() * len(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
test_acc = 100. * correct / len(test_loader.dataset)
print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
test_loss, correct, len(test_loader.dataset), test_acc))
```
在训练函数中,我们首先将模型设置为训练模式(model.train())。然后对每个批次的数据进行前向传递、计算损失、反向传播和优化器更新。最后打印训练损失。
在测试函数中,我们首先将模型设置为评估模式(model.eval())。然后遍历测试集中的每个批次,进行前向传递和计算损失。最后计算测试集上的平均损失和准确率。
接下来,我们需要定义主函数来训练模型:
```python
def main():
# 定义超参数
batch_size = 64
num_epochs = 10
learning_rate = 0.001
num_classes = 131
# 定义数据增强和预处理
transform_train = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
transform_test = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# 加载训练集和测试集
train_dataset = datasets.ImageFolder(train_path, transform=transform_train)
test_dataset = datasets.ImageFolder(test_path, transform=transform_test)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# 定义模型、损失函数和优化器
model = FruitClassifier(num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# 训练模型
for epoch in range(1, num_epochs + 1):
train(model, train_loader, optimizer, criterion)
test(model, test_loader, criterion)
```
在主函数中,我们首先定义超参数,包括批量大小、训练周期和学习率。然后定义数据增强和预处理。接下来加载训练集和测试集,并使用torch.utils.data.DataLoader进行批量处理和数据加载。然后定义模型、损失函数和优化器。最后训练模型。
最后,我们需要将训练好的模型应用到图片识别界面中。以下是完整的代码:
```python
import sys
from PyQt5.QtWidgets import QApplication, QWidget, QLabel, QPushButton, QVBoxLayout, QHBoxLayout, QFileDialog
from PyQt5.QtGui import QPixmap
from PyQt5.QtCore import Qt
from PIL import Image
class MainWindow(QWidget):
def __init__(self):
super().__init__()
self.initUI()
# 加载训练好的模型
self.model = FruitClassifier(num_classes=131)
self.model.load_state_dict(torch.load('./fruits-360-resnet18.pth', map_location=torch.device('cpu')))
self.model.eval()
# 定义类别名称
self.classes = os.listdir(train_path)
# 定义预处理和转换
self.transform = transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
def initUI(self):
# 定义界面控件
self.label_img = QLabel(self)
self.label_img.setAlignment(Qt.AlignCenter)
self.label_pred = QLabel(self)
self.label_pred.setAlignment(Qt.AlignCenter)
self.button_browse = QPushButton('Browse', self)
self.button_browse.clicked.connect(self.browseImage)
# 定义布局
hbox = QHBoxLayout()
hbox.addWidget(self.button_browse)
vbox = QVBoxLayout()
vbox.addWidget(self.label_img)
vbox.addWidget(self.label_pred)
vbox.addLayout(hbox)
self.setLayout(vbox)
# 定义窗口
self.setGeometry(100, 100, 400, 400)
self.setWindowTitle('Fruit Image Classifier')
self.show()
def browseImage(self):
# 打开文件对话框
file_dialog = QFileDialog()
file_dialog.setFileMode(QFileDialog.ExistingFile)
file_dialog.setNameFilter('Images (*.png *.xpm *.jpg *.bmp)')
if file_dialog.exec_():
file_path = file_dialog.selectedFiles()[0]
self.classifyImage(file_path)
def classifyImage(self, file_path):
# 加载图片并进行预处理和转换
img = Image.open(file_path)
img = self.transform(img)
img = img.unsqueeze(0)
# 进行预测
with torch.no_grad():
output = self.model(img)
pred = output.argmax(dim=1).item()
pred_name = self.classes[pred]
pred_prob = F.softmax(output, dim=1)[0][pred].item()
# 显示图片和预测结果
pixmap = QPixmap(file_path).scaledToWidth(300)
self.label_img.setPixmap(pixmap)
self.label_pred.setText('Prediction: {} ({:.1f}%)'.format(pred_name, pred_prob * 100))
if __name__ == '__main__':
app = QApplication(sys.argv)
window = MainWindow()
sys.exit(app.exec_())
```
在图片识别界面中,我们使用QLabel来显示图片和预测结果,并使用QPushButton来打开文件对话框。当用户选择一个图片时,我们将其路径传递给classifyImage函数。在classifyImage函数中,我们首先加载图片并进行预处理和转换。然后使用训练好的模型进行预测,并显示预测结果。
最后,我们需要保存训练好的模型:
```python
# 保存模型
torch.save(model.state_dict(), './fruits-360-resnet18.pth')
```
现在,我们可以运行主函数来训练模型并保存模型。然后运行图片识别界面来测试模型。
阅读全文