基于CNN手写数字识别代码及数据集
时间: 2023-09-16 13:12:00 浏览: 57
以下是基于CNN的手写数字识别代码及数据集:
首先,我们需要导入必要的库和数据集:
```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from keras.datasets import mnist
# 加载 MNIST 数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 将数据集中的像素值缩放到 [0, 1] 范围内
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
# 将标签转换为 one-hot 编码
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)
```
接下来,我们定义一个基于 CNN 的模型:
```python
model = keras.Sequential(
[
layers.Reshape(target_shape=(28, 28, 1), input_shape=(28, 28)),
layers.Conv2D(filters=32, kernel_size=(3, 3), activation="relu"),
layers.MaxPooling2D(pool_size=(2, 2)),
layers.Conv2D(filters=64, kernel_size=(3, 3), activation="relu"),
layers.MaxPooling2D(pool_size=(2, 2)),
layers.Flatten(),
layers.Dense(units=128, activation="relu"),
layers.Dropout(rate=0.5),
layers.Dense(units=10, activation="softmax"),
]
)
```
然后,我们进行模型的编译和训练:
```python
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.fit(x_train, y_train, batch_size=128, epochs=5, validation_split=0.1)
```
最后,我们测试模型的性能:
```python
test_loss, test_accuracy = model.evaluate(x_test, y_test)
print(f"Test accuracy: {test_accuracy}")
```
以上便是基于 CNN 的手写数字识别代码及数据集。