写一份基于卷积神经网络动物分类的代码
时间: 2023-07-20 14:27:58 浏览: 84
好的,以下是一个基于卷积神经网络的动物分类代码,使用的是Python和TensorFlow框架:
```python
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, callbacks
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# 创建数据生成器
train_datagen = ImageDataGenerator(
rescale=1./255,
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode='nearest')
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
'train_dir',
target_size=(224, 224),
batch_size=32,
class_mode='categorical')
validation_generator = test_datagen.flow_from_directory(
'val_dir',
target_size=(224, 224),
batch_size=32,
class_mode='categorical')
# 创建模型
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(224, 224, 3)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(128, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(128, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Flatten())
model.add(layers.Dense(512, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
# 编译模型
model.compile(loss='categorical_crossentropy',
optimizer=optimizers.RMSprop(lr=1e-4),
metrics=['acc'])
# 创建回调函数
checkpoint_callback = callbacks.ModelCheckpoint(
'model.h5', monitor='val_acc', save_best_only=True, mode='max')
earlystop_callback = callbacks.EarlyStopping(
monitor='val_acc', patience=5, mode='max')
# 训练模型
history = model.fit(
train_generator,
steps_per_epoch=100,
epochs=50,
validation_data=validation_generator,
validation_steps=50,
callbacks=[checkpoint_callback, earlystop_callback])
# 保存模型
model.save('model.h5')
```
以上代码中,我们使用了ImageDataGenerator来生成训练集和验证集,使用了Sequential模型和Conv2D、MaxPooling2D、Flatten、Dense等层来构建卷积神经网络模型,使用了RMSprop优化器、categorical_crossentropy损失函数和acc指标来编译模型,使用了ModelCheckpoint和EarlyStopping回调函数来保存模型和提前结束训练。
阅读全文