mnist手写数字数据集knn完整代码
时间: 2024-07-01 20:00:37 浏览: 170
MNIST手写数字数据集是一个常用的数据集,用于图像分类任务,特别是机器学习和深度学习的入门教程。K-近邻(K-Nearest Neighbors, KNN)算法在这种情况下可以用来做基本的分类。以下是一个使用Python和scikit-learn库实现KNN分类MNIST数据集的简单示例代码:
```python
# 导入所需库
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report
# 加载MNIST数据集
mnist = fetch_openml('mnist_784', version=1)
X, y = mnist['data'], mnist['target']
# 数据预处理
X = X.astype('float32') / 255.0 # 将像素值归一化到0-1范围
y = y.astype('int64')
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 标准化数据
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
# 创建KNN分类器
knn = KNeighborsClassifier(n_neighbors=5)
# 训练模型
knn.fit(X_train, y_train)
# 预测测试集
y_pred = knn.predict(X_test)
# 评估模型性能
print("Classification Report:")
print(classification_report(y_test, y_pred))
# 如果你需要运行完整代码,请确保安装了所需的库(如scikit-learn)
```
在这个代码中,我们首先加载数据并预处理,然后划分训练集和测试集。接下来,我们创建一个KNN分类器,设置邻居数为5,然后训练模型。最后,我们用测试集进行预测,并打印出分类报告来评估模型的性能。
阅读全文