基于pytorch的水果图像识别与分类系统的设计与实现,数据集使用Fruits 360,要求编写转换函数对数据集进行数据增强,模型要实现标准量化和批量归一化,并且实现权重衰减,梯度裁剪和Adam优化,最后将训练好的模型保存下来,并利用该模型实现一个网页端的分类系统
时间: 2023-06-17 19:04:40 浏览: 97
水果识别训练模型与数据集
这是一个相对复杂的项目,需要进行以下几个步骤:
1. 数据集获取与预处理:从Fruits 360数据集中下载数据,并将其转换为pytorch中的Dataset格式。同时,为了避免过拟合,需要对数据进行数据增强,比如随机裁剪、随机旋转、随机翻转、随机调整亮度、对比度等。
2. 模型设计与训练:根据数据集的特点设计一个适合的卷积神经网络模型。在训练过程中,需要进行标准量化和批量归一化、权重衰减、梯度裁剪和Adam优化等操作,以提高模型的精度和泛化能力。
3. 模型保存与部署:在训练完模型后,将其保存到本地文件中,以便后续使用。然后,使用Flask等Web框架,将模型部署到Web服务器上,实现一个网页端的分类系统。
下面是一个简单的代码示例,帮助你更好地理解该项目的实现过程:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torch.utils.data import Dataset
from PIL import Image
# 定义数据增强函数
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(15),
transforms.ColorJitter(brightness=0.5, contrast=0.5),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
transform_test = transforms.Compose([
transforms.Resize(32),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# 定义数据集类
class FruitsDataset(Dataset):
def __init__(self, root, transform=None):
self.root = root
self.transform = transform
self.filenames = []
self.labels = []
self.classes = []
self.class_to_idx = {}
for i, class_name in enumerate(sorted(os.listdir(root))):
self.class_to_idx[class_name] = i
self.classes.append(class_name)
class_dir = os.path.join(root, class_name)
for filename in os.listdir(class_dir):
self.filenames.append(os.path.join(class_dir, filename))
self.labels.append(i)
def __getitem__(self, index):
filename = self.filenames[index]
img = Image.open(filename).convert('RGB')
label = self.labels[index]
if self.transform is not None:
img = self.transform(img)
return img, label
def __len__(self):
return len(self.filenames)
# 定义卷积神经网络模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(32)
self.relu1 = nn.ReLU()
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(64)
self.relu2 = nn.ReLU()
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
self.bn3 = nn.BatchNorm2d(128)
self.relu3 = nn.ReLU()
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc1 = nn.Linear(128 * 4 * 4, 256)
self.dropout = nn.Dropout(p=0.5)
self.fc2 = nn.Linear(256, 120)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)
x = self.pool1(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu2(x)
x = self.pool2(x)
x = self.conv3(x)
x = self.bn3(x)
x = self.relu3(x)
x = self.pool3(x)
x = x.view(x.size(0), -1)
x = self.fc1(x)
x = self.dropout(x)
x = self.fc2(x)
return x
# 定义训练函数
def train(model, train_loader, criterion, optimizer):
model.train()
running_loss = 0.0
for i, (inputs, labels) in enumerate(train_loader):
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
running_loss += loss.item() * inputs.size(0)
epoch_loss = running_loss / len(train_loader.dataset)
return epoch_loss
# 定义测试函数
def test(model, test_loader, criterion):
model.eval()
running_loss = 0.0
running_corrects = 0
with torch.no_grad():
for inputs, labels in test_loader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
loss = criterion(outputs, labels)
_, preds = torch.max(outputs, 1)
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)
epoch_loss = running_loss / len(test_loader.dataset)
epoch_acc = running_corrects.double() / len(test_loader.dataset)
return epoch_loss, epoch_acc
if __name__ == '__main__':
# 加载数据集
train_set = FruitsDataset('fruits-360/Training', transform=transform_train)
test_set = FruitsDataset('fruits-360/Test', transform=transform_test)
# 划分训练集和验证集
train_size = int(0.8 * len(train_set))
valid_size = len(train_set) - train_size
train_set, valid_set = random_split(train_set, [train_size, valid_size])
# 定义数据加载器
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_set, batch_size=64, shuffle=False, num_workers=4)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=4)
# 定义模型、损失函数和优化器
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
# 训练模型
best_acc = 0.0
for epoch in range(10):
train_loss = train(model, train_loader, criterion, optimizer)
valid_loss, valid_acc = test(model, valid_loader, criterion)
print('Epoch: {} Train Loss: {:.4f} Valid Loss: {:.4f} Valid Acc: {:.4f}'.format(
epoch + 1, train_loss, valid_loss, valid_acc))
if valid_acc > best_acc:
best_acc = valid_acc
torch.save(model.state_dict(), 'fruits_model.pt')
# 加载最佳模型
model.load_state_dict(torch.load('fruits_model.pt'))
# 在测试集上评估模型
test_loss, test_acc = test(model, test_loader, criterion)
print('Test Loss: {:.4f} Test Acc: {:.4f}'.format(test_loss, test_acc))
```
最后,你可以使用Flask框架将模型部署到Web服务器上,实现一个网页端的分类系统。具体步骤如下:
1. 安装Flask框架:```pip install Flask```
2. 创建一个app.py文件,并添加以下代码:
```python
from flask import Flask, request, jsonify
from PIL import Image
import io
import torch
import torchvision.transforms as transforms
app = Flask(__name__)
# 加载模型
model = Net()
model.load_state_dict(torch.load('fruits_model.pt'))
model.eval()
# 定义数据预处理函数
def preprocess(image_bytes):
transform = transforms.Compose([
transforms.Resize(32),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
image = Image.open(io.BytesIO(image_bytes))
image = transform(image).unsqueeze(0)
return image
# 定义分类函数
def classify(image_bytes):
image = preprocess(image_bytes)
with torch.no_grad():
output = model(image)
_, predicted = torch.max(output, 1)
return predicted.item()
# 定义路由
@app.route('/', methods=['GET'])
def index():
return 'Hello, World!'
@app.route('/predict', methods=['POST'])
def predict():
if 'file' not in request.files:
return 'No file uploaded!'
file = request.files['file']
image_bytes = file.read()
class_id = classify(image_bytes)
class_name = train_set.classes[class_id]
return jsonify({'class_id': class_id, 'class_name': class_name})
if __name__ == '__main__':
app.run()
```
3. 在命令行中运行以下命令启动Web服务器:
```bash
export FLASK_APP=app.py
flask run
```
4. 在浏览器中访问http://localhost:5000/predict,上传一张水果图片,即可得到该图片的分类结果。
阅读全文