基于KNN的手写数字识别实现代码
时间: 2024-09-17 18:08:17 浏览: 57
基于KNN (K-Nearest Neighbors) 的手写数字识别通常涉及以下步骤:
1. **数据预处理**:首先,需要获取手写数字数据集,如MNIST。数据集通常包含灰度图像和相应的标签。然后对数据进行标准化或归一化,以便所有像素值都在0到1之间。
```python
from sklearn.datasets import load_digits
import numpy as np
# 加载数据
digits = load_digits()
images = digits.images.reshape((len(digits.images), -1)) / 16.
labels = digits.target
```
2. **分割数据**:将数据集划分为训练集和测试集,常用80%的数据用于训练,剩下的20%用于测试模型性能。
```python
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(images, labels, test_size=0.2)
```
3. **计算距离**:KNN算法的核心是找到最近邻。这里可以使用欧氏距离或其他相似度度量。
4. **构建KNN模型**:使用`sklearn.neighbors.KNeighborsClassifier` 类,并设置`n_neighbors` 参数表示邻居的数量。
```python
from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)
```
5. **预测及评估**:利用模型对测试集进行预测,然后使用准确率等指标评估模型性能。
```python
y_pred = knn.predict(X_test)
accuracy = knn.score(X_test, y_test)
print("Accuracy:", accuracy)
```
阅读全文