tensorflow2.0手写mnist识别
时间: 2023-04-27 12:06:10 浏览: 141
TensorFlow 2.是一个流行的深度学习框架,可以用于许多不同的任务,包括手写数字识别。手写数字识别是一个经典的机器学习问题,它涉及将手写数字图像分类为数字到9中的一个。
在TensorFlow 2.中,可以使用Keras API来构建和训练模型。以下是一个简单的手写数字识别模型的示例代码:
```python
import tensorflow as tf
from tensorflow import keras
# 加载MNIST数据集
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# 将图像数据归一化到到1之间
train_images = train_images / 255.
test_images = test_images / 255.
# 定义模型
model = keras.Sequential([
keras.layers.Flatten(input_shape=(28, 28)),
keras.layers.Dense(128, activation='relu'),
keras.layers.Dense(10, activation='softmax')
])
# 编译模型
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# 训练模型
model.fit(train_images, train_labels, epochs=5)
# 评估模型
test_loss, test_acc = model.evaluate(test_images, test_labels)
print('Test accuracy:', test_acc)
```
在这个例子中,我们首先加载MNIST数据集,然后将图像数据归一化到到1之间。接下来,我们定义一个简单的神经网络模型,该模型包括一个输入层、一个隐藏层和一个输出层。我们使用ReLU激活函数和softmax激活函数来激活隐藏层和输出层。最后,我们编译模型并训练它,使用Adam优化器和稀疏分类交叉熵损失函数。在训练完成后,我们评估模型的性能,输出测试准确率。
这只是一个简单的例子,你可以根据需要调整模型的结构和超参数来获得更好的性能。
阅读全文