MNIST handwritten digit recognition in Python
Date: 2023-09-14 09:04:40
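MNIST is a widely used dataset of handwritten digits. It contains 60,000 training samples and 10,000 test samples. Each sample is a 28x28-pixel grayscale image, and its label is one digit from 0 to 9. The steps below perform MNIST handwritten digit recognition in Python using TensorFlow's Keras API: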
1. Import the required libraries
```python
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
```
2. Load the MNIST dataset
```python
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
```
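Before going further, it can help to confirm that the arrays have the expected layout. The helper below is a hypothetical sanity check. The name `check_mnist_arrays` is ours and is not part of Keras. It assumes that `load_data()` returns `uint8` image arrays of shape `(num_samples, 28, 28)` and label arrays of shape `(num_samples,)`, as the Keras documentation describes.

```python
import numpy as np

def check_mnist_arrays(images, labels, expected_count):
    """Return a list of problems; an empty list means the arrays look like MNIST.

    Hypothetical helper: assumes (N, 28, 28) uint8 images and (N,) labels in 0-9.
    """
    problems = []
    if images.shape != (expected_count, 28, 28):
        problems.append(f"images shape {images.shape}, expected {(expected_count, 28, 28)}")
    if images.dtype != np.uint8:
        problems.append(f"images dtype {images.dtype}, expected uint8")
    if labels.shape != (expected_count,):
        problems.append(f"labels shape {labels.shape}, expected {(expected_count,)}")
    if labels.size and (labels.min() < 0 or labels.max() > 9):
        problems.append("labels outside the range 0-9")
    return problems
```

When applied to the arrays loaded above, `check_mnist_arrays(train_images, train_labels, 60000)` and `check_mnist_arrays(test_images, test_labels, 10000)` should both return an empty list.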
3. Preprocess the data (scale pixel values from 0-255 to 0-1)
```python
train_images = train_images / 255.0
test_images = test_images / 255.0
```
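Dividing by 255.0 maps the integer pixel values 0-255 onto the range 0-1. However, dividing a `uint8` array by the Python float `255.0` produces `float64`. The sketch below shows an alternative helper of our own, `normalize_images`, which casts to `float32` first. This halves memory use and matches the `float32` precision that Keras layers use by default.

```python
import numpy as np

def normalize_images(images):
    """Scale uint8 pixels in [0, 255] to float32 values in [0, 1]."""
    return images.astype(np.float32) / 255.0

# Tiny hypothetical stand-in for an image batch: one 1x3 "image"
sample = np.array([[[0, 128, 255]]], dtype=np.uint8)
scaled = normalize_images(sample)
print(scaled.dtype)            # float32
print((sample / 255.0).dtype)  # float64, as produced by the original step
```

With the real data, this would be `train_images = normalize_images(train_images)` and likewise for `test_images`. Training works with either dtype.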
4. Build the model
```python
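model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),  # 28x28 image -> 784-element vector
    keras.layers.Dense(128, activation='relu'),  # fully connected hidden layer
    keras.layers.Dense(10)                       # one raw score (logit) per digit 0-9
])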
```
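Here is a worked check on the architecture. `Flatten` turns each 28x28 image into a 784-element vector and has no trainable weights. Each `Dense` layer has one weight per input-output pair plus one bias per unit. The arithmetic below uses plain Python, so no TensorFlow is needed. It gives the parameter counts that `model.summary()` should report for this model.

```python
def dense_params(inputs, units):
    """Trainable parameters of a Dense layer: an inputs x units weight matrix plus one bias per unit."""
    return inputs * units + units

flatten_outputs = 28 * 28                    # Flatten has no trainable parameters
hidden = dense_params(flatten_outputs, 128)  # 784 * 128 + 128
output = dense_params(128, 10)               # 128 * 10 + 10
print(hidden, output, hidden + output)       # 100480 1290 101770
```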
5. Compile the model
```python
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
```
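`from_logits=True` tells the loss that the final `Dense(10)` layer outputs raw scores (logits) rather than probabilities, so the softmax is applied inside the loss. The NumPy sketch below computes the same quantity: the mean of `-log(softmax(z)[y])` over a batch, where `z` is a row of logits and `y` is its integer label. It uses the log-sum-exp shift for numerical stability. It illustrates the math and is not Keras's actual implementation.

```python
import numpy as np

def sparse_ce_from_logits(logits, labels):
    """Mean of -log(softmax(logits)[label]) over the batch."""
    logits = np.asarray(logits, dtype=np.float64)
    # Subtracting each row's max keeps exp() from overflowing and leaves softmax unchanged.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class in each row (integer "sparse" labels).
    return float(-log_probs[np.arange(len(labels)), labels].mean())

logits = np.array([[2.0, 0.5, -1.0],
                   [0.0, 0.0, 0.0]])
labels = np.array([0, 2])
print(sparse_ce_from_logits(logits, labels))
```

Integer labels such as `train_labels` are why the *sparse* variant is used. One-hot labels would call for `CategoricalCrossentropy` instead.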
6. Train the model
```python
model.fit(train_images, train_labels, epochs=10)
```
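`model.fit` splits the training set into mini-batches and takes one optimizer step per batch. One epoch is one full pass over all 60,000 images. The sketch below mimics that bookkeeping in NumPy. It assumes Keras's default `batch_size=32` and per-epoch shuffling (`shuffle=True` is the `fit` default). It is an illustration, not how Keras iterates internally.

```python
import math
import numpy as np

def epoch_batches(num_samples, batch_size=32, seed=0):
    """Yield shuffled index arrays that cover every sample exactly once (one epoch)."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(num_samples)
    for start in range(0, num_samples, batch_size):
        yield order[start:start + batch_size]  # the last batch may be smaller

batches = list(epoch_batches(60000))
print(len(batches), math.ceil(60000 / 32))  # 1875 1875
```

With 60,000 samples and batches of 32, this gives 1,875 steps per epoch. That matches the `1875/1875` counter in Keras's progress bar, so `epochs=10` means 18,750 updates in total.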
7. Evaluate the model on the test set
```python
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print('\nTest accuracy:', test_acc)
```
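For integer labels, the `'accuracy'` metric amounts to comparing the highest-scoring class with the true label and averaging the matches. The NumPy sketch below spells that out using a hypothetical helper, `accuracy_from_logits`, and toy data. Applied to `model.predict(test_images)` and `test_labels`, it should agree closely with `test_acc`.

```python
import numpy as np

def accuracy_from_logits(logits, labels):
    """Fraction of samples whose highest-scoring class equals the true label."""
    predicted = np.argmax(logits, axis=1)
    return float(np.mean(predicted == labels))

logits = np.array([[0.1, 2.0, -1.0],  # predicts class 1
                   [3.0, 0.0, 0.5],   # predicts class 0
                   [0.2, 0.1, 0.0]])  # predicts class 0
labels = np.array([1, 0, 2])
print(accuracy_from_logits(logits, labels))  # 2 of 3 correct
```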
8. Use the model to make predictions
```python
probability_model = tf.keras.Sequential([model, tf.keras.layers.Softmax()])
predictions = probability_model.predict(test_images)
```
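The appended `Softmax` layer converts each row of logits into probabilities that are positive and sum to 1, without changing which class scores highest. A NumPy equivalent, for illustration:

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax, the same transformation the appended Softmax layer applies."""
    shifted = logits - logits.max(axis=-1, keepdims=True)  # for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

logits = np.array([[1.0, 2.0, 3.0],
                   [5.0, 5.0, 5.0]])
probs = softmax(logits)
print(probs.sum(axis=1))         # each row sums to 1 (up to rounding)
print(np.argmax(probs, axis=1))  # [2 0], the same as the argmax of the raw logits
```

So `np.argmax(predictions[0])` gives the same digit whether or not the softmax is applied. The probabilities are mainly useful for showing the model's confidence, as the visualization step does.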
9. Visualize the predictions
```python
def plot_image(i, predictions_array, true_label, img):
    # Show test image i; the caption reads "predicted label, confidence (true label)"
    predictions_array, true_label, img = predictions_array[i], true_label[i], img[i]
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(img, cmap=plt.cm.binary)

    predicted_label = np.argmax(predictions_array)
    # Blue caption for a correct prediction, red for a wrong one
    color = 'blue' if predicted_label == true_label else 'red'
    plt.xlabel("{} {:2.0f}% ({})".format(predicted_label,
                                         100 * np.max(predictions_array),
                                         true_label),
               color=color)


def plot_value_array(i, predictions_array, true_label):
    # Bar chart of the 10 class probabilities for test image i
    predictions_array, true_label = predictions_array[i], true_label[i]
    plt.grid(False)
    plt.xticks(range(10))
    plt.yticks([])
    thisplot = plt.bar(range(10), predictions_array, color="#777777")
    plt.ylim([0, 1])
    predicted_label = np.argmax(predictions_array)
    # Predicted class in red, true class in blue (blue wins when they are the same bar)
    thisplot[predicted_label].set_color('red')
    thisplot[true_label].set_color('blue')


# Plot the first 15 test images, each followed by its probability bar chart
num_rows = 5
num_cols = 3
num_images = num_rows * num_cols
plt.figure(figsize=(2 * 2 * num_cols, 2 * num_rows))
for i in range(num_images):
    plt.subplot(num_rows, 2 * num_cols, 2 * i + 1)
    plot_image(i, predictions, test_labels, test_images)
    plt.subplot(num_rows, 2 * num_cols, 2 * i + 2)
    plot_value_array(i, predictions, test_labels)
plt.tight_layout()
plt.show()
```
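The grid shows the first 15 test images, whether or not they were classified correctly. If you would rather inspect the mistakes, a small helper of our own, `misclassified_indices`, can pick them out. It is sketched here on hypothetical toy probabilities.

```python
import numpy as np

def misclassified_indices(predictions, true_labels):
    """Indices where the most probable class differs from the true label."""
    predicted = np.argmax(predictions, axis=1)
    return np.flatnonzero(predicted != true_labels)

# Hypothetical stand-in probabilities for 4 images and 2 classes
preds = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]])
labels = np.array([0, 0, 1, 1])
print(misclassified_indices(preds, labels))  # [1 2]
```

To plot only the errors, loop over `enumerate(misclassified_indices(predictions, test_labels)[:num_images])`. Use the loop position to choose the subplot and the returned index as the `i` passed to `plot_image` and `plot_value_array`.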
Those are all the steps for recognizing MNIST handwritten digits in Python.