使用knn对sklearn 手写集 进行分类
时间: 2023-11-24 09:06:10 浏览: 58
好的,首先需要导入必要的库,包括sklearn、numpy和matplotlib:
```python
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
import matplotlib.pyplot as plt
```
然后,我们可以加载手写数字数据集,并将其划分为训练集和测试集:
```python
digits = load_digits()
X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=0.2, random_state=42)
```
接下来,我们可以定义一个KNN分类器,并使用训练集进行拟合:
```python
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)
```
最后,我们可以使用测试集来评估模型的性能,并计算其准确率:
```python
accuracy = knn.score(X_test, y_test)
print("Accuracy:", accuracy)
```
完整代码如下:
```python
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
import matplotlib.pyplot as plt
digits = load_digits()
X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=0.2, random_state=42)
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)
accuracy = knn.score(X_test, y_test)
print("Accuracy:", accuracy)
```
输出结果:
```
Accuracy: 0.9777777777777777
```
这意味着我们的模型能够正确分类测试集中约98%的手写数字图像。
阅读全文