TensorFlow 2.x 玉米病害分类代码
时间: 2023-08-01 09:03:09 浏览: 161
tensorflow二分类源码.zip
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Train and evaluate a small CNN image classifier (corn-disease images) from
# class-per-sub-folder image directories.

# Global TensorFlow seed: makes weight initialisation repeatable.
# NOTE: this alone does not make runs fully reproducible. The generators'
# shuffling is seeded separately below via `seed=SEED`, and some GPU kernels
# remain non-deterministic.
SEED = 42
tf.random.set_seed(SEED)

# Shared input / training settings (previously repeated magic numbers).
IMG_SIZE = (150, 150)
BATCH_SIZE = 32
EPOCHS = 10

# Rescale pixel values from [0, 255] to [0, 1] and reserve 20% of the
# training directory as a validation subset.
datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

# Data directories; flow_from_directory treats each sub-folder as one class.
train_dir = 'train_directory'
test_dir = 'test_directory'

# Training / validation subsets drawn from the same directory. Seeding keeps
# the shuffling order repeatable between runs.
train_generator = datagen.flow_from_directory(
    train_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='training',
    seed=SEED,
)
validation_generator = datagen.flow_from_directory(
    train_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='validation',
    seed=SEED,
)

# Size the output layer from the data instead of hard-coding 4, so the model
# matches however many class sub-folders train_dir actually contains.
num_classes = train_generator.num_classes

# Three Conv/MaxPool stages followed by a dense classifier head.
model = keras.Sequential([
    keras.Input(shape=IMG_SIZE + (3,)),  # RGB images of size IMG_SIZE
    keras.layers.Conv2D(32, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D(2, 2),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D(2, 2),
    keras.layers.Conv2D(128, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D(2, 2),
    keras.layers.Flatten(),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(num_classes, activation='softmax'),
])

# One-hot labels (class_mode='categorical') pair with categorical_crossentropy.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

model.fit(train_generator, epochs=EPOCHS, validation_data=validation_generator)

# Evaluate on the held-out test directory. A plain rescaling generator is used
# (no validation_split needed here). The class list is pinned to the training
# label order: otherwise the test generator would infer classes from
# test_dir's own sub-folders, and any difference in folder sets would silently
# misalign label indices with the model's outputs.
test_datagen = ImageDataGenerator(rescale=1./255)
class_names = sorted(
    train_generator.class_indices, key=train_generator.class_indices.get
)
test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    classes=class_names,
    shuffle=False,  # deterministic evaluation order
)

# evaluate() returns [loss, accuracy], in the order configured in compile().
results = model.evaluate(test_generator)
print('测试集准确率:', results[1])
阅读全文