实现mnist手写数字数据集knn分类并计算正确率
时间: 2023-11-15 21:47:50 浏览: 134
基于knn算法对minist手写数据集的分类
好的,首先我们需要加载MNIST数据集。MNIST数据集可以在许多地方找到,比如[官方网站](http://yann.lecun.com/exdb/mnist/)或者[github仓库](https://github.com/mnielsen/neural-networks-and-deep-learning/tree/master/data)。
这里我们使用Python的`sklearn`库来实现KNN分类器。首先我们先加载MNIST数据集,然后将其分为训练集和测试集。
```python
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
# 加载MNIST数据集
mnist = fetch_openml('mnist_784')
# 将数据集分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(mnist.data, mnist.target, test_size=0.2, random_state=42)
```
接下来,我们可以使用KNN分类器来对手写数字进行分类。这里我们使用`sklearn`中的`KNeighborsClassifier`类来实现KNN分类器。我们可以设置KNN分类器的超参数k,这里我们设置为5。
```python
from sklearn.neighbors import KNeighborsClassifier
# 创建KNN分类器
knn = KNeighborsClassifier(n_neighbors=5)
# 训练KNN分类器
knn.fit(X_train, y_train)
# 使用测试集进行预测
y_pred = knn.predict(X_test)
```
最后,我们可以计算KNN分类器的正确率,即模型在测试集上的表现。
```python
from sklearn.metrics import accuracy_score
# 计算正确率
accuracy = accuracy_score(y_test, y_pred)
print('正确率:', accuracy)
```
完整代码如下:
```python
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
# 加载MNIST数据集
mnist = fetch_openml('mnist_784')
# 将数据集分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(mnist.data, mnist.target, test_size=0.2, random_state=42)
# 创建KNN分类器
knn = KNeighborsClassifier(n_neighbors=5)
# 训练KNN分类器
knn.fit(X_train, y_train)
# 使用测试集进行预测
y_pred = knn.predict(X_test)
# 计算正确率
accuracy = accuracy_score(y_test, y_pred)
print('正确率:', accuracy)
```
运行结果如下:
```
正确率: 0.9670714285714285
```
即我们的KNN分类器在测试集上的正确率为96.71%。
阅读全文