基于bp神经网络的手写数字识别 python代码
时间: 2023-09-11 19:01:55 浏览: 117
基于BP神经网络的手写数字识别
基于BP神经网络的手写数字识别是一种常见的机器学习任务,可以使用Python编写实现。下面是一个简单的代码示例:
1. 数据准备:
- 首先,需要准备好用于训练的手写数字图片数据集。可以使用MNIST数据集,其中包含了60000个训练样本和10000个测试样本。
- 然后,将图片数据集转换为模型可接受的输入格式。对于MNIST数据集,可以将28x28的图片展开为1维向量,并将像素值进行归一化处理。
2. 导入相关库:
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report
```
3. 加载数据集:
```python
X, y = fetch_openml('mnist_784', version=1, return_X_y=True, cache=True)
X = StandardScaler().fit_transform(X)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
```
4. 定义BP神经网络模型:
```python
class BPNeuralNetwork():
def __init__(self, input_size, hidden_size, output_size):
self.W1 = np.random.randn(input_size, hidden_size)
self.W2 = np.random.randn(hidden_size, output_size)
def forward(self, X):
self.z2 = np.dot(X, self.W1)
self.a2 = self.sigmoid(self.z2)
self.z3 = np.dot(self.a2, self.W2)
self.a3 = self.sigmoid(self.z3)
return self.a3
def sigmoid(self, z):
return 1 / (1 + np.exp(-z))
```
5. 定义训练函数:
```python
def train(model, X, y, learning_rate, num_epochs):
for epoch in range(num_epochs):
# 前向传播
output = model.forward(X)
# 反向传播
delta3 = output - y
delta2 = np.dot(delta3, model.W2.T) * model.a2 * (1 - model.a2)
# 权重更新
dW2 = np.dot(model.a2.T, delta3)
dW1 = np.dot(X.T, delta2)
model.W2 -= learning_rate * dW2
model.W1 -= learning_rate * dW1
# 打印训练过程
if (epoch + 1) % 10 == 0:
loss = np.mean(np.square(output - y))
print(f'Epoch {epoch+1}/{num_epochs}, loss: {loss:.4f}')
```
6. 定义预测函数:
```python
def predict(model, X):
output = model.forward(X)
return np.argmax(output, axis=1)
```
7. 训练模型并进行预测:
```python
model = BPNeuralNetwork(input_size=X_train.shape[1], hidden_size=100, output_size=10)
train(model, X_train, y_train, learning_rate=0.01, num_epochs=100)
y_pred = predict(model, X_test)
```
8. 评估模型性能:
```python
print(classification_report(y_test, y_pred))
```
这只是一个简单的示例,真实的BP神经网络模型可能需要更复杂的结构和参数调优。希望这个简单的代码示例能帮到您!
阅读全文