python手写数字识别显示九张图片代码
时间: 2023-09-03 16:05:52 浏览: 68
这里提供一个使用Python和MNIST数据集手写数字识别的示例代码,可以显示九张随机选择的图片和它们的预测结果:
```python
import random
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
# 加载MNIST数据集
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# 构建模型
model = keras.Sequential([
keras.layers.Flatten(input_shape=(28, 28)),
keras.layers.Dense(128, activation='relu'),
keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# 训练模型
model.fit(train_images, train_labels, epochs=5)
# 随机选择9张图片进行预测
sample_indexes = random.sample(range(test_images.shape[0]), 9)
sample_images = test_images[sample_indexes]
sample_labels = test_labels[sample_indexes]
predictions = model.predict(sample_images)
# 显示图片和预测结果
plt.figure(figsize=(10,10))
for i in range(9):
plt.subplot(3, 3, i+1)
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.imshow(sample_images[i], cmap=plt.cm.binary)
predicted_label = np.argmax(predictions[i])
true_label = sample_labels[i]
if predicted_label == true_label:
color = 'green'
else:
color = 'red'
plt.xlabel("{} ({})".format(predicted_label, true_label), color=color)
plt.show()
```
这段代码使用了Keras库构建神经网络模型,并使用MNIST数据集进行训练和预测。在训练完成后,从测试集中随机选择9张图片进行预测,并将结果显示出来。
阅读全文