写一个python的代码,要求:MNIST手写数字识别 数据集下载:https://yann,lecun,com/exdb/mnist/要求:训练多个Epochs;显示测试集的识别精度;画出训练曲线,展示至少4个识别结果
时间: 2024-12-19 22:24:26 浏览: 28
为了实现这个需求,我们首先需要安装必要的库,如TensorFlow、Keras、matplotlib以及numpy。然后我们将加载MNIST数据集,编写模型并训练它,最后评估性能并绘制训练曲线。这是一个基本的示例代码:
```python
# 导入所需的库
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt
# 下载MNIST数据集
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# 归一化图像到0-1范围
train_images, test_images = train_images / 255.0, test_images / 255.0
# 将标签转换为one-hot编码
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
# 定义模型
model = Sequential([
Flatten(input_shape=(28, 28)),
Dense(128, activation='relu'),
Dense(10, activation='softmax')
])
# 编译模型
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# 训练模型
epochs = 4
history = model.fit(train_images, train_labels, epochs=epochs, validation_split=0.2)
# 绘制训练曲线
plt.figure(figsize=(12, 6))
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
# 选择四个样本进行预测展示
sample_indices = [0, 10, 20, 30] # 可视化前几行测试数据
fig, axs = plt.subplots(2, 2, figsize=(10, 10))
for i, ax in enumerate(axs.flatten()):
img_index = sample_indices[i]
img = test_images[img_index].reshape(28, 28)
pred_label = model.predict_classes(img.reshape(1, -1))[0]
true_label = np.argmax(test_labels[img_index])
ax.imshow(img, cmap='gray')
ax.set_title(f"Predicted: {pred_label}, True: {true_label}")
ax.axis('off')
plt.tight_layout()
plt.show()
阅读全文