LeNet-5训练自己的数据集的代码
时间: 2023-09-01 13:07:51 浏览: 102
使用LeNet-5进行字符识别:在MNIST数据集上训练的深度模型(LeNet-5)用于字符识别。-matlab开发
以下是使用LeNet-5模型训练自己的数据集的示例代码:
```python
import tensorflow as tf
import numpy as np
import cv2
import os
# 数据集路径
dataset_path = 'path/to/dataset'
# 定义LeNet-5模型
def LeNet5(input_shape, num_classes):
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(filters=6, kernel_size=(5, 5), activation='relu', input_shape=input_shape),
tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
tf.keras.layers.Conv2D(filters=16, kernel_size=(5, 5), activation='relu'),
tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(units=120, activation='relu'),
tf.keras.layers.Dense(units=84, activation='relu'),
tf.keras.layers.Dense(units=num_classes, activation='softmax')
])
return model
# 加载数据集
def load_dataset():
images = []
labels = []
# 遍历数据集文件夹
for root, dirs, files in os.walk(dataset_path):
for file in files:
# 读取图像
img = cv2.imread(os.path.join(root, file))
# 调整图像大小
img = cv2.resize(img, (32, 32))
# 将图像添加到列表中
images.append(img)
# 提取标签
label = int(os.path.basename(root))
# 将标签添加到列表中
labels.append(label)
# 转换为numpy数组
images = np.array(images)
labels = np.array(labels)
return images, labels
# 加载数据集
images, labels = load_dataset()
# 将标签转换为one-hot编码
labels = tf.keras.utils.to_categorical(labels)
# 将数据集分为训练集和测试集
train_images, test_images, train_labels, test_labels = tf.keras.utils.train_test_split(images, labels, test_size=0.2)
# 构建LeNet-5模型
model = LeNet5(input_shape=(32, 32, 3), num_classes=10)
# 编译模型
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# 训练模型
model.fit(train_images, train_labels, epochs=10, batch_size=32, validation_data=(test_images, test_labels))
# 评估模型
test_loss, test_acc = model.evaluate(test_images, test_labels)
print('Test accuracy:', test_acc)
# 保存模型
model.save('lenet5_model.h5')
```
在上面的代码中,我们首先定义了一个LeNet-5模型,然后使用load_dataset函数加载数据集,并将标签转换为one-hot编码。接着,我们将数据集分为训练集和测试集,并使用train_images和train_labels训练模型。最后,我们评估模型的准确性,并将模型保存到lenet5_model.h5文件中。
阅读全文