LeNet-5识别手写数字完整python代码,我的数据集放在D盘
时间: 2024-03-11 22:47:20 浏览: 115
以下是一个使用 LeNet-5 模型来识别手写数字的完整 Python 代码示例,假设数据集是以 MNIST 格式存储在 `D:/mnist` 文件夹中:
```python
import gzip
import numpy as np
import tensorflow as tf
def load_mnist(path, kind='train'):
"""Load MNIST data from `path`"""
labels_path = f'{path}/{kind}-labels-idx1-ubyte.gz'
images_path = f'{path}/{kind}-images-idx3-ubyte.gz'
with gzip.open(labels_path, 'rb') as lbpath:
labels = np.frombuffer(lbpath.read(), dtype=np.uint8, offset=8)
with gzip.open(images_path, 'rb') as imgpath:
images = np.frombuffer(imgpath.read(), dtype=np.uint8, offset=16).reshape(len(labels), 784)
return images, labels
# 加载数据集
X_train, y_train = load_mnist('D:/mnist', kind='train')
X_test, y_test = load_mnist('D:/mnist', kind='test')
# 将数据集转换为浮点数类型,并归一化到 [0, 1]
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255
# 将标签转换为 one-hot 编码
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)
# 定义 LeNet-5 模型
model = tf.keras.Sequential([
tf.keras.layers.Reshape((28, 28, 1), input_shape=(784,)),
tf.keras.layers.Conv2D(filters=6, kernel_size=(5, 5), activation='relu'),
tf.keras.layers.AveragePooling2D(),
tf.keras.layers.Conv2D(filters=16, kernel_size=(5, 5), activation='relu'),
tf.keras.layers.AveragePooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(units=120, activation='relu'),
tf.keras.layers.Dense(units=84, activation='relu'),
tf.keras.layers.Dense(units=10, activation='softmax')
])
# 编译模型
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# 训练模型
model.fit(X_train, y_train, epochs=10, batch_size=128, validation_data=(X_test, y_test))
# 评估模型
test_loss, test_acc = model.evaluate(X_test, y_test, verbose=2)
print(f'Test accuracy: {test_acc}')
```
在这里,我们首先使用 `load_mnist()` 函数加载数据集,并将其转换为浮点数类型,并归一化到 [0, 1]。然后,我们将标签转换为 one-hot 编码。接下来,我们定义了一个包含两个卷积层和三个全连接层的 LeNet-5 模型,并使用 `compile()` 方法编译模型。然后,我们使用 `fit()` 方法训练模型,并使用 `evaluate()` 方法评估模型的性能。
阅读全文