基于bp神经网络关于数字手写体的识别python代码
时间: 2024-10-21 17:01:31 浏览: 30
基于BP神经网络的手写数字识别通常涉及到图像预处理、特征提取、训练神经网络以及预测等步骤。这里是一个简单的Python代码示例,使用了`numpy`和`matplotlib`库,以及经典的MNIST数据集:
```python
# 导入所需的库
import numpy as np
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense
from keras.utils import to_categorical
# 加载MNIST数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 数据预处理
x_train = x_train.reshape(-1, 28 * 28) / 255.0 # 归一化像素值
x_test = x_test.reshape(-1, 28 * 28) / 255.0
y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)
# 构建BP神经网络模型
model = Sequential()
model.add(Dense(128, activation='relu', input_shape=(784,))) # 输入层到隐藏层
model.add(Dense(64, activation='relu')) # 隐藏层
model.add(Dense(10, activation='softmax')) # 输出层,10个类别
# 编译模型
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# 训练模型
model.fit(x_train, y_train, epochs=10, validation_data=(x_test, y_test))
# 测试模型
loss, accuracy = model.evaluate(x_test, y_test)
print(f"Test accuracy: {accuracy}")
# 示例预测
prediction_index = np.argmax(model.predict(x_test[0].reshape(1, -1)))
predicted_digit = prediction_index + 1
print(f"Predicted digit: {predicted_digit}")
阅读全文