用python训练mnist数据集
时间: 2024-10-06 07:00:18 浏览: 31
用python 将mnist 数据集转化为图片
在Python中,我们可以使用深度学习库如TensorFlow或PyTorch来训练MNIST手写数字识别数据集。MNIST是一个非常基础但经典的图像分类任务,包含60,000张训练图片和10,000张测试图片,每张图片都是28x28像素灰度图像。
以下是使用Keras(一个高级API,可以简化神经网络的构建)训练MNIST的基本步骤:
```python
# 导入必要的库
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
# 加载并预处理数据
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images = train_images / 255.0 # 归一化到[0,1]范围
test_images = test_images / 255.0
train_images = train_images.reshape(-1, 28 * 28) # 将二维数组展平成一维
test_images = test_images.reshape(-1, 28 * 28)
# 构建模型
model = Sequential([
Flatten(input_shape=(28, 28)), # 展开输入层
Dense(128, activation='relu'), # 隐藏层
Dense(10, activation='softmax') # 输出层,10个节点对应0-9的类别
])
# 编译模型
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# 训练模型
model.fit(train_images, train_labels, epochs=10, validation_split=0.2) # 使用一部分数据做验证
# 评估模型性能
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(f"Test accuracy: {test_acc}")
#
阅读全文