手写数字识别python
时间: 2023-09-14 21:10:13 浏览: 70
手写数字识别是一种常见的机器学习任务,可以通过训练一个神经网络来实现。以下是一个使用Python实现手写数字识别的例子。
1. 数据集准备
首先,我们需要准备MNIST数据集,它包含了很多手写数字的图片和对应的标签。可以使用以下代码下载并加载数据集:
```python
import numpy as np
from keras.datasets import mnist
# 加载MNIST数据集
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# 调整数据格式
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype('float32') / 255
test_images = test_images.reshape((10000, 28 * 28))
test_images = test_images.astype('float32') / 255
# 将标签转换为one-hot编码
from keras.utils import to_categorical
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
```
2. 构建模型
接下来,我们需要构建一个神经网络模型。这里我们使用Keras库来构建模型。以下是一个简单的两层全连接网络:
```python
from keras import models
from keras import layers
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(28 * 28,)))
network.add(layers.Dense(10, activation='softmax'))
```
3. 编译模型
在训练模型之前,我们需要编译模型。这里我们选择交叉熵作为损失函数,使用随机梯度下降法来优化模型:
```python
network.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])
```
4. 训练模型
现在,我们可以使用训练数据来训练模型了:
```python
network.fit(train_images, train_labels, epochs=5, batch_size=128)
```
5. 评估模型
最后,我们可以使用测试数据来评估模型的准确率:
```python
test_loss, test_acc = network.evaluate(test_images, test_labels)
print('test_acc:', test_acc)
```
阅读全文