设计一个卷积神经网络识别MNIST手写数字并进行测试
时间: 2023-08-11 10:05:53 浏览: 119
好的,以下是一个简单的卷积神经网络(CNN)模型,用于识别 MNIST 手写数字:
```python
import tensorflow as tf
from tensorflow.keras import layers
# 加载 MNIST 数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 将像素值归一化到 0~1 之间
x_train, x_test = x_train / 255.0, x_test / 255.0
# 添加一个维度,将数据格式转换为 (样本数, 高度, 宽度, 通道数)
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]
# 创建卷积神经网络模型
model = tf.keras.Sequential([
layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
layers.MaxPooling2D(),
layers.Conv2D(64, 3, activation='relu'),
layers.MaxPooling2D(),
layers.Flatten(),
layers.Dense(128, activation='relu'),
layers.Dense(10, activation='softmax')
])
# 编译模型
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# 训练模型
model.fit(x_train, y_train, epochs=5)
# 测试模型
model.evaluate(x_test, y_test)
```
这个模型包含两个卷积层和两个池化层,以及两个全连接层,其中最后一个输出层使用 softmax 激活函数进行多分类。在训练期间,我们使用 Adam 优化器和交叉熵损失函数进行优化,同时追踪模型的准确率指标。在测试期间,我们使用测试集进行评估,并打印模型在测试集上的准确率。你可以使用上述代码来训练和测试这个 CNN 模型来识别 MNIST 手写数字。
阅读全文