Animal image classification code based on LeNet5 or AlexNet
Date: 2023-08-08 14:12:49
Below is an example of animal image classification code based on AlexNet:
```python
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Dataset paths: each directory holds one subfolder per animal class
train_dir = "path/to/train/directory"
validation_dir = "path/to/validation/directory"

# Image generators: scale pixels to [0, 1]; augment only the training images
train_datagen = ImageDataGenerator(
    rescale=1./255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True
)
validation_datagen = ImageDataGenerator(rescale=1./255)

# Training data: images resized to 224x224, labels one-hot encoded ('categorical')
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical'
)

# Validation data
validation_generator = validation_datagen.flow_from_directory(
    validation_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode='categorical'
)
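# Output units must match the number of class subfolders found in train_dir
num_classes = train_generator.num_classes

# AlexNet-style model. padding='same' on the stride-1 convolutions (as in the
# original AlexNet) keeps a 5x5x256 feature map before Flatten instead of
# shrinking it to 1x1 with the default 'valid' padding.
model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(224, 224, 3)),
    tf.keras.layers.Conv2D(96, (11, 11), strides=(4, 4), activation='relu'),
    tf.keras.layers.MaxPooling2D((3, 3), strides=(2, 2)),
    tf.keras.layers.Conv2D(256, (5, 5), padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D((3, 3), strides=(2, 2)),
    tf.keras.layers.Conv2D(384, (3, 3), padding='same', activation='relu'),
    tf.keras.layers.Conv2D(384, (3, 3), padding='same', activation='relu'),
    tf.keras.layers.Conv2D(256, (3, 3), padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D((3, 3), strides=(2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(4096, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(4096, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(num_classes, activation='softmax')
])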
# Compile the model: SGD with learning rate 0.01, categorical cross-entropy loss
model.compile(
    loss='categorical_crossentropy',
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
    metrics=['accuracy']
)
# Train the model. Without steps_per_epoch/validation_steps, Keras runs
# len(generator) batches per epoch, so every image (including a final
# partial batch) is used, even when a directory has fewer than 32 images.
model.fit(
    train_generator,
    epochs=50,
    validation_data=validation_generator
)
```
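Keras `Conv2D` and `MaxPooling2D` default to `padding='valid'`, where a layer with kernel size k and stride s maps a side length n to `floor((n - k) / s) + 1`; with `padding='same'` it becomes `ceil(n / s)`. The pure-Python sketch below (no TensorFlow required; `conv_out` and `trace` are illustrative helper names, not Keras functions) traces the 224×224 input through the convolution and pooling layers of the AlexNet example, once with 'valid' padding everywhere and once with 'same' padding on the stride-1 convolutions only, which is an assumption modeled on the padded 5×5 and 3×3 layers of the original AlexNet.

```python
import math

def conv_out(n, k, s=1, padding="valid"):
    """Side length after a Keras-style Conv2D or MaxPooling2D layer."""
    if padding == "same":
        return math.ceil(n / s)
    return (n - k) // s + 1

# (kind, kernel size, stride) for the conv/pool stack of the AlexNet example
LAYERS = [
    ("conv", 11, 4), ("pool", 3, 2),
    ("conv", 5, 1), ("pool", 3, 2),
    ("conv", 3, 1), ("conv", 3, 1), ("conv", 3, 1),
    ("pool", 3, 2),
]

def trace(n, stride1_conv_padding="valid"):
    """Return the side length after each layer.

    Only the stride-1 convolutions use `stride1_conv_padding`; the first
    convolution and all pooling layers keep the default 'valid' padding.
    """
    sizes = []
    for kind, k, s in LAYERS:
        pad = stride1_conv_padding if (kind == "conv" and s == 1) else "valid"
        n = conv_out(n, k, s, pad)
        sizes.append(n)
    return sizes

valid_sizes = trace(224, "valid")
same_sizes = trace(224, "same")
print(valid_sizes)  # [54, 26, 22, 10, 8, 6, 4, 1]
print(same_sizes)   # [54, 26, 26, 12, 12, 12, 12, 5]
# Features reaching Flatten (last conv layer has 256 filters)
print("Flatten size:", valid_sizes[-1] ** 2 * 256, "vs", same_sizes[-1] ** 2 * 256)
```

With 'valid' padding the feature map collapses to 1×1 before `Flatten`, so only 256 features feed the first 4096-unit `Dense` layer; padding the stride-1 convolutions keeps a 5×5×256 map (6,400 features).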
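Note: The code above uses AlexNet, a classic convolutional network architecture for image classification. Because datasets and training environments differ, you will need to adjust and tune it for your own situation, for example the number of classes, the input size, the batch size, the learning rate, and the number of epochs.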
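One adjustment that always depends on the dataset is the number of units in the final `Dense` layer, which must equal the number of classes. When `classes` is not passed, `flow_from_directory` treats each subdirectory of the given directory as one class and assigns label indices in alphanumeric order; with `class_mode='categorical'` each label becomes a one-hot vector. The stdlib-only sketch below reproduces that mapping so the class count can be checked before building the model. The helpers `infer_class_indices` and `one_hot` and the four animal folder names are hypothetical.

```python
import os
import tempfile

def infer_class_indices(directory):
    """Map each subdirectory name to a label index, in sorted order."""
    classes = sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )
    return {name: i for i, name in enumerate(classes)}

def one_hot(index, num_classes):
    """One-hot label vector, as used with class_mode='categorical'."""
    vec = [0.0] * num_classes
    vec[index] = 1.0
    return vec

# Hypothetical layout: train_dir/<class_name>/<image files>
with tempfile.TemporaryDirectory() as train_dir:
    for name in ["dog", "cat", "horse", "bird"]:
        os.makedirs(os.path.join(train_dir, name))
    class_indices = infer_class_indices(train_dir)

num_classes = len(class_indices)
print(class_indices)  # {'bird': 0, 'cat': 1, 'dog': 2, 'horse': 3}
print(num_classes)    # 4
print(one_hot(class_indices["dog"], num_classes))  # [0.0, 0.0, 1.0, 0.0]
```

In the Keras pipeline itself the same mapping is available as `train_generator.class_indices`, and its length gives the value to use for the output layer.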