花卉识别 tensorflow
时间: 2024-12-27 19:18:30 浏览: 4
### 花卉图像分类与识别
#### 加载和预处理数据集
为了构建一个有效的花卉图像分类模型,`tf.keras.utils.image_dataset_from_directory` 提供了一个简便的方法来从目录结构中读取图像并创建 `TensorFlow Dataset API` 的输入管道。此函数能够自动解析子文件夹中的类别标签,并支持多种参数配置以适应不同的需求[^1]。
```python
import tensorflow as tf
from pathlib import Path
data_dir = Path('path/to/flower_photos') # 替换为实际路径
batch_size = 32
img_height = 180
img_width = 180
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=0.2, # 划分验证集比例
subset="training",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
validation_split=0.2,
subset="validation",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
```
#### 构建卷积神经网络(CNN)模型
对于图像分类任务而言,CNN 是一种非常成功的架构选择。下面展示的是一个简单的 CNN 模型定义过程:
```python
num_classes = 5 # 对应五种不同类型的花
model = Sequential([
layers.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
layers.Conv2D(16, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(32, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(64, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
layers.Flatten(),
layers.Dense(128, activation='relu'),
layers.Dense(num_classes)
])
```
#### 编译与训练模型
完成上述准备工作之后就可以编译模型并开始训练了。这里采用 Adam 优化算法以及稀疏交叉熵损失函数来进行监督学习。
```python
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
epochs = 10
history = model.fit(
train_ds,
validation_data=val_ds,
epochs=epochs
)
```
通过以上步骤可以成功搭建起一套基于 TensorFlow 平台上的花卉图像分类系统[^3]。
阅读全文