请给出基于CNN搭建一个水果分类器的代码
时间: 2023-10-29 15:05:56 浏览: 104
好的,下面是基于TensorFlow深度学习框架搭建一个简单的水果分类器的代码:
```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
# 加载数据集
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
"fruits/train",
validation_split=0.2,
subset="training",
seed=123,
image_size=(64, 64),
batch_size=32,
)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
"fruits/train",
validation_split=0.2,
subset="validation",
seed=123,
image_size=(64, 64),
batch_size=32,
)
test_ds = tf.keras.preprocessing.image_dataset_from_directory(
"fruits/test",
seed=123,
image_size=(64, 64),
batch_size=32,
)
# 数据归一化
normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)
train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y))
test_ds = test_ds.map(lambda x, y: (normalization_layer(x), y))
# 搭建CNN模型
model = keras.Sequential([
layers.experimental.preprocessing.Rescaling(1./255, input_shape=(64, 64, 3)),
layers.Conv2D(32, 3, activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(64, 3, activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(128, 3, activation='relu'),
layers.MaxPooling2D(),
layers.Flatten(),
layers.Dense(128, activation='relu'),
layers.Dense(5)
])
# 编译模型
model.compile(
optimizer='adam',
loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
# 训练模型
model.fit(
train_ds,
validation_data=val_ds,
epochs=10
)
# 评估模型
model.evaluate(test_ds)
# 保存模型
model.save('fruit_classifier.h5')
```
这段代码实现了对水果数据集进行训练、验证和测试的功能,并将训练好的模型保存在`fruit_classifier.h5`文件中。需要注意的是,这只是一个简单的示例代码,实际应用中需要根据具体情况进行调整和优化。
阅读全文