基于深度学习实现垃圾分类,该如何进行数据训练的源代码
时间: 2023-12-24 13:02:14 浏览: 28
下面是一个基于深度学习实现垃圾分类的数据训练的源代码示例,使用Keras框架和TensorFlow后端:
```
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# 定义数据集路径和大小
train_dir = 'path/to/training/dataset'
val_dir = 'path/to/validation/dataset'
train_size = 2000
val_size = 800
# 数据增强
train_datagen = ImageDataGenerator(rescale=1./255,
rotation_range=40,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True)
val_datagen = ImageDataGenerator(rescale=1./255)
# 从文件夹中读取数据并进行数据增强
train_generator = train_datagen.flow_from_directory(train_dir,
target_size=(150, 150),
batch_size=32,
class_mode='categorical')
val_generator = val_datagen.flow_from_directory(val_dir,
target_size=(150, 150),
batch_size=32,
class_mode='categorical')
# 构建模型
model = keras.Sequential([
Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)),
MaxPooling2D(2, 2),
Conv2D(64, (3, 3), activation='relu'),
MaxPooling2D(2, 2),
Conv2D(128, (3, 3), activation='relu'),
MaxPooling2D(2, 2),
Conv2D(128, (3, 3), activation='relu'),
MaxPooling2D(2, 2),
Flatten(),
Dropout(0.5),
Dense(512, activation='relu'),
Dense(4, activation='softmax')
])
# 编译模型
model.compile(loss='categorical_crossentropy',
optimizer=keras.optimizers.RMSprop(lr=1e-4),
metrics=['acc'])
# 训练模型
history = model.fit_generator(train_generator,
steps_per_epoch=train_size//32,
epochs=100,
validation_data=val_generator,
validation_steps=val_size//32)
# 保存模型
model.save('garbage_classification_model.h5')
```
在上述代码中,我们首先使用`ImageDataGenerator`来定义数据增强的方式,包括旋转、平移、剪切、缩放和水平翻转等操作。然后使用`flow_from_directory`方法从文件夹中读取数据并进行数据增强。接着构建一个包含多个卷积层、池化层和全连接层的卷积神经网络,最后编译模型并使用`fit_generator`方法对模型进行训练。最终,我们使用`save`方法将模型保存到本地。