基于TensorFlow下的花朵识别代码
时间: 2024-03-23 15:37:32 浏览: 50
好的,以下是一个基于TensorFlow的花卉识别代码的示例:
```python
import tensorflow as tf
import numpy as np
import os
import cv2
# 定义花卉类别
classes = ['daisy', 'dandelion', 'rose', 'sunflower', 'tulip']
num_classes = len(classes)
# 定义输入图像的大小和通道数
img_size = 128
num_channels = 3
# 加载图像函数
def load_image(image_path):
img = cv2.imread(image_path)
img = cv2.resize(img, (img_size, img_size), cv2.INTER_LINEAR)
img = np.array(img, dtype=np.float32)
img = img.reshape(-1, img_size, img_size, num_channels)
return img
# 构建模型
def build_model():
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu', input_shape=(img_size, img_size, num_channels)),
tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation='relu'),
tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(units=128, activation='relu'),
tf.keras.layers.Dense(units=num_classes, activation='softmax')
])
model.compile(optimizer=tf.keras.optimizers.Adam(), loss='categorical_crossentropy', metrics=['accuracy'])
return model
# 加载模型
model = build_model()
model.load_weights('flower_model.h5')
# 加载图像
image_path = 'test.jpg'
image = load_image(image_path)
# 进行预测
prediction = model.predict(image)
class_index = np.argmax(prediction)
class_name = classes[class_index]
print('Predicted class:', class_name)
```
在这个示例中,我们首先定义了花卉的类别,并指定了输入图像的大小和通道数。然后,我们定义了加载图像的函数和构建模型的函数。在加载模型后,我们加载了测试图像,并使用模型进行预测,输出预测结果。
需要注意的是,这个示例只是一个简单的示例,实际应用中可能需要更复杂的模型和更多的预处理步骤。同时,我们也需要足够的训练数据才能训练出准确的模型。
阅读全文