基于tensorflow中的keras编写代码,文件夹rawdata下的两个子文件夹的所有图片 1. 编写数据加载函数; 2. 根据任务要求对数据集进行划分; 3. 数据增强; 4. 构建深度学习模型和损失函数; 5. 编写模型训练相关代码,完成模型训练 6. 使用可视化库 Matplotlib 对训练过程进行可视化展示,如损失值变化、准确率变化、召回率、ROC曲线等。 7. 将训练好的模型保存
时间: 2023-11-28 13:50:34 浏览: 88
以下是可能的代码实现,仅作参考:
```python
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
# 1. 数据加载函数
def load_data(data_dir, batch_size):
train_dir = os.path.join(data_dir, 'train')
val_dir = os.path.join(data_dir, 'val')
train_datagen = ImageDataGenerator(rescale=1./255, rotation_range=20,
width_shift_range=0.1, height_shift_range=0.1,
shear_range=0.2, zoom_range=0.2, horizontal_flip=True)
train_generator = train_datagen.flow_from_directory(train_dir, target_size=(224, 224),
batch_size=batch_size, class_mode='binary')
val_datagen = ImageDataGenerator(rescale=1./255)
val_generator = val_datagen.flow_from_directory(val_dir, target_size=(224, 224),
batch_size=batch_size, class_mode='binary')
return train_generator, val_generator
# 2. 数据集划分
data_dir = 'rawdata'
batch_size = 32
train_generator, val_generator = load_data(data_dir, batch_size)
# 3. 数据增强
# 4. 构建模型和损失函数
def build_model():
model = keras.Sequential([
keras.layers.Conv2D(32, 3, activation='relu', input_shape=(224, 224, 3)),
keras.layers.MaxPooling2D(),
keras.layers.Conv2D(64, 3, activation='relu'),
keras.layers.MaxPooling2D(),
keras.layers.Conv2D(128, 3, activation='relu'),
keras.layers.MaxPooling2D(),
keras.layers.Flatten(),
keras.layers.Dense(128, activation='relu'),
keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
return model
model = build_model()
model.summary()
# 5. 模型训练
epochs = 10
history = model.fit(train_generator, epochs=epochs, validation_data=val_generator)
# 6. 可视化展示
def plot_history(history):
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
x = range(epochs)
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(x, acc, label='Training accuracy')
plt.plot(x, val_acc, label='Validation accuracy')
plt.title('Accuracy')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(x, loss, label='Training loss')
plt.plot(x, val_loss, label='Validation loss')
plt.title('Loss')
plt.legend()
plt.show()
plot_history(history)
# 7. 模型保存
model.save('my_model.h5')
```
阅读全文