基于 TensorFlow 的最大均值差异（MMD）水果图像分类完整代码
时间: 2023-06-16 14:02:16 浏览: 149
以下是基于 TensorFlow、利用最大均值差异（Maximum Mean Discrepancy，MMD）进行水果图像分类的完整代码：
```python
import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
from PIL import Image
# Hyperparameters shared by the data pipeline, the model and the training loop.
batch_size = 32        # images per batch (used for both data splits)
learning_rate = 0.001  # Adam learning rate
num_epochs = 50        # full passes over the training split
num_classes = 3        # width of the softmax output layer
# Load the datasets.
def load_dataset():
    """Load the training and validation image splits from disk.

    Both splits are read with ``image_dataset_from_directory`` from
    'data/train' and 'data/val', resized to 224x224, batched with
    ``batch_size`` and labelled one-hot (``label_mode='categorical'``).

    Returns:
        ``(train_dataset, val_dataset)`` tuple.
    """
    shared_options = dict(
        image_size=(224, 224),
        batch_size=batch_size,
        label_mode='categorical',
    )
    # The training split is created first, then the validation split.
    splits = tuple(
        tf.keras.preprocessing.image_dataset_from_directory(split_dir, **shared_options)
        for split_dir in ('data/train', 'data/val')
    )
    return splits
# MMD loss between two feature samples.
def compute_mmd(x, y, sigma=5.0):
    """Return the biased squared Maximum Mean Discrepancy between two samples.

    Uses a Gaussian (RBF) kernel ``k(a, b) = exp(-||a - b||^2 / (2 * sigma^2))``::

        MMD^2 = mean(k(x, x)) + mean(k(y, y)) - 2 * mean(k(x, y))

    Args:
        x: rank-2 tensor ``(n, d)``, one sample per row.
        y: rank-2 tensor ``(m, d)``; ``m`` may differ from ``n``.
        sigma: kernel bandwidth.

    Returns:
        Scalar tensor.
    """
    def gaussian_kernel(a, b):
        # Squared distances are summed directly instead of squaring tf.norm():
        # the gradient of tf.norm is NaN where a row is compared with itself
        # (sqrt at 0, i.e. the whole kernel diagonal), which would turn every
        # weight update that backpropagates through the MMD into NaN.
        diff = a[:, tf.newaxis, :] - b[tf.newaxis, :, :]
        sq_dist = tf.reduce_sum(tf.square(diff), axis=-1)
        return tf.exp(-sq_dist / (2.0 * sigma ** 2))

    return (tf.reduce_mean(gaussian_kernel(x, x))
            + tf.reduce_mean(gaussian_kernel(y, y))
            - 2.0 * tf.reduce_mean(gaussian_kernel(x, y)))
# Define the model.
def create_model():
    """Build a transfer-learning classifier on a frozen ImageNet ResNet50.

    The 224x224x3 input is passed through
    ``tf.keras.applications.resnet50.preprocess_input`` inside the model, then
    through the ResNet50 backbone (no top; every layer frozen; called with
    ``training=False``). The result is global-average-pooled and classified by
    a softmax Dense layer with ``num_classes`` units.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    image_shape = (224, 224, 3)
    backbone = tf.keras.applications.ResNet50(
        input_shape=image_shape,
        include_top=False,
        weights='imagenet',
    )
    # Freeze every pretrained layer so that only the new classifier head is trained.
    for frozen_layer in backbone.layers:
        frozen_layer.trainable = False

    image_in = tf.keras.Input(shape=image_shape)
    scaled = tf.keras.applications.resnet50.preprocess_input(image_in)
    feature_map = backbone(scaled, training=False)
    pooled = tf.keras.layers.GlobalAveragePooling2D()(feature_map)
    probabilities = tf.keras.layers.Dense(num_classes, activation='softmax')(pooled)
    return tf.keras.Model(image_in, probabilities)
# Training: cross-entropy on labelled source batches plus MMD feature alignment.
def train(train_dataset, val_dataset, mmd_weight=1.0):
    """Train the classifier with a classification + MMD domain-alignment loss.

    Every step takes one labelled batch from ``train_dataset`` (source) and one
    batch from ``val_dataset`` (target; its labels are not used in the update)
    and minimises::

        CE(y_src, p_src) + mmd_weight * compute_mmd(f_src, f_tgt)

    where ``f`` are the pooled backbone features that feed the final Dense
    layer. After each epoch the model is evaluated on ``val_dataset``.

    A custom loop is used because a Keras ``compile`` loss only receives
    ``(y_true, y_pred)`` of a single batch. It cannot see intermediate
    features, nor a source and a target batch at the same time.

    Args:
        train_dataset: dataset of ``(images, one_hot_labels)`` batches.
        val_dataset: dataset of the same form; used both as the MMD target and
            for the per-epoch validation metrics.
        mmd_weight: weight of the MMD term; ``0`` trains with cross-entropy only.

    Returns:
        ``tf.keras.callbacks.History`` whose ``history`` dict maps
        ``train_loss``, ``train_accuracy``, ``val_loss`` and ``val_accuracy`` to
        per-epoch values, and whose ``model`` attribute is the trained model.
    """
    model = create_model()
    # A second view of the same layers (weights are shared) that also outputs
    # the pooled features, i.e. the input of the final Dense classifier.
    dual_model = tf.keras.Model(model.inputs, [model.layers[-1].input, model.output])

    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    cross_entropy = tf.keras.losses.CategoricalCrossentropy()
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_acc = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')
    val_loss = tf.keras.metrics.Mean(name='val_loss')
    val_acc = tf.keras.metrics.CategoricalAccuracy(name='val_accuracy')
    epoch_metrics = (train_loss, train_acc, val_loss, val_acc)

    @tf.function
    def train_step(x_src, y_src, x_tgt):
        with tf.GradientTape() as tape:
            feat_src, prob_src = dual_model(x_src, training=True)
            feat_tgt, _ = dual_model(x_tgt, training=True)
            loss = (cross_entropy(y_src, prob_src)
                    + mmd_weight * compute_mmd(feat_src, feat_tgt))
        variables = dual_model.trainable_variables
        optimizer.apply_gradients(zip(tape.gradient(loss, variables), variables))
        train_loss.update_state(loss)
        train_acc.update_state(y_src, prob_src)

    @tf.function
    def val_step(x, y):
        _, prob = dual_model(x, training=False)
        val_loss.update_state(cross_entropy(y, prob))
        val_acc.update_state(y, prob)

    history = tf.keras.callbacks.History()
    history.set_model(model)
    history.history = {metric.name: [] for metric in epoch_metrics}
    history.epoch = []
    # Cycle the target split so that every source batch gets a partner, even
    # when the validation split has fewer batches than the training split.
    target_batches = val_dataset.repeat()

    for epoch in range(num_epochs):
        for metric in epoch_metrics:
            metric.reset_state()
        # zip() stops when the (finite) training split is exhausted.
        for (x_src, y_src), (x_tgt, _) in zip(train_dataset, target_batches):
            train_step(x_src, y_src, x_tgt)
        for x_val, y_val in val_dataset:
            val_step(x_val, y_val)

        history.epoch.append(epoch)
        for metric in epoch_metrics:
            history.history[metric.name].append(float(metric.result()))
        print(f'Epoch {epoch + 1}/{num_epochs} - '
              + ' - '.join(f'{m.name}: {history.history[m.name][-1]:.4f}'
                           for m in epoch_metrics))
    return history
# Load the train/validation splits, then train a model on them.
train_dataset, val_dataset = load_dataset()
history = train(train_dataset=train_dataset, val_dataset=val_dataset)
# Plot the accuracy and loss recorded during training.
def plot_history(history):
    """Plot training/validation accuracy and loss curves side by side.

    Args:
        history: object with a ``history`` dict holding one value per epoch.
            Training series may be named ``train_accuracy``/``train_loss``
            (this script's ``train``) or ``accuracy``/``loss`` (``Model.fit``).
            Validation series must be ``val_accuracy`` and ``val_loss``.

    Raises:
        KeyError: if a required series is missing; the message lists the keys
            that are available.
    """
    curves = history.history

    def series(*candidates):
        # Return the first candidate key present, so both naming schemes work.
        for key in candidates:
            if key in curves:
                return curves[key]
        raise KeyError(f'none of {candidates} in history; available keys: {sorted(curves)}')

    acc = series('train_accuracy', 'accuracy')
    val_acc = series('val_accuracy')
    loss = series('train_loss', 'loss')
    val_loss = series('val_loss')
    # The x-axis length comes from the recorded data rather than num_epochs, so
    # the plot still works if training stopped early or ran for a different length.
    epochs_range = range(len(acc))

    plt.figure(figsize=(16, 8))
    plt.subplot(1, 2, 1)
    plt.plot(epochs_range, acc, label='Training Accuracy')
    plt.plot(epochs_range, val_acc, label='Validation Accuracy')
    plt.legend(loc='lower right')
    plt.title('Training and Validation Accuracy')
    plt.subplot(1, 2, 2)
    plt.plot(epochs_range, loss, label='Training Loss')
    plt.plot(epochs_range, val_loss, label='Validation Loss')
    plt.legend(loc='upper right')
    plt.title('Training and Validation Loss')
    plt.show()


plot_history(history)
# Evaluate the trained model on the held-out test split.
def test_model(model, test_dir='data/test'):
    """Compute, print and return the classification accuracy on a test directory.

    ``test_dir`` must contain one sub-directory per class. Class indices follow
    the sorted sub-directory names, the same order that
    ``image_dataset_from_directory`` uses for the training labels. Files
    without a supported image extension are skipped.

    Each image is converted to RGB, resized to 224x224 and given to the model
    as raw 0-255 pixel values, with no ``/255`` rescale. The model applies
    ``resnet50.preprocess_input`` itself, and the training batches from
    ``load_dataset`` are not rescaled either.

    Args:
        model: trained Keras model mapping ``(1, 224, 224, 3)`` images to
            class probabilities.
        test_dir: root directory of the test split.

    Returns:
        The accuracy as a float in [0, 1], or ``None`` if no test images were found.
    """
    image_exts = ('.bmp', '.gif', '.jpeg', '.jpg', '.png')
    # Only directories are classes; stray files (e.g. .DS_Store) would
    # otherwise shift the class indices and crash os.listdir below.
    fruit_names = sorted(
        name for name in os.listdir(test_dir)
        if os.path.isdir(os.path.join(test_dir, name))
    )
    correct = 0
    total = 0
    for fruit_id, fruit_name in enumerate(fruit_names):
        fruit_dir = os.path.join(test_dir, fruit_name)
        for filename in sorted(os.listdir(fruit_dir)):
            if not filename.lower().endswith(image_exts):
                continue
            with Image.open(os.path.join(fruit_dir, filename)) as img:
                # convert('RGB') gives 3 channels for grayscale/RGBA/palette
                # files; bilinear matches image_dataset_from_directory's default.
                rgb = img.convert('RGB').resize((224, 224), Image.BILINEAR)
            pixels = np.asarray(rgb, dtype=np.float32)[np.newaxis, ...]
            # Calling the model directly avoids predict()'s per-call overhead
            # for a single-image batch.
            pred = np.asarray(model(pixels, training=False))
            if int(np.argmax(pred)) == fruit_id:
                correct += 1
            total += 1
    if total == 0:
        print('Test Accuracy: no test images found in', test_dir)
        return None
    accuracy = correct / total
    print('Test Accuracy:', accuracy)
    return accuracy


# train() returns only its History; the trained model is exposed as history.model.
test_model(history.model)
```
该代码首先加载数据集，然后定义 MMD 损失函数和模型。在训练函数中，MMD 项用于度量训练集与验证集特征分布之间的差异，并作为损失的一部分参与训练。最后，代码绘制训练过程中的准确率和损失曲线，并评估模型在测试集上的准确率。
阅读全文