Visualizing Classification Accuracy of the KNN Algorithm on the Iris Dataset
Date: 2024-09-14 13:14:34
The K-Nearest Neighbors (KNN) algorithm is an instance-based learning method commonly used for classification: a sample is assigned the majority class among its k closest training samples. Below is a simple Python example that uses the sklearn library to classify the iris dataset with KNN and visualize how accuracy varies with k:
First, make sure the `scikit-learn`, `matplotlib`, and `seaborn` libraries are installed. If not, you can install them with:
```bash
pip install scikit-learn matplotlib seaborn
```
Then we can write the code:
```python
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns

# Apply seaborn's default plot styling
sns.set_theme()

# Load the iris dataset
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Split into training and test sets (80% / 20%)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Range of k values to evaluate
ks = range(1, 31)
accuracies = []

# Compute and store the test accuracy for each k
for k in ks:
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(X_train, y_train)
    predictions = knn.predict(X_test)
    accuracies.append(accuracy_score(y_test, predictions))

# Plot accuracy as a function of k
plt.plot(ks, accuracies, 'b-o', label='Accuracy')
plt.xlabel('Number of neighbors (k)')
plt.ylabel('Accuracy')
plt.title('KNN Accuracy vs K for Iris dataset')
plt.legend(loc='upper right')
plt.show()
```
This code displays a plot with k on the horizontal axis and the corresponding classification accuracy on the vertical axis. By inspecting the plot, you can find the k value that gives the best model performance.
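Note that accuracy on a single train/test split can be noisy, especially on a small dataset like iris. A sketch of a more robust way to pick k, using scikit-learn's `cross_val_score` with 5-fold cross-validation instead of one fixed split (the variable names here are illustrative):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

iris = load_iris()
X, y = iris.data, iris.target

# Mean 5-fold cross-validation accuracy for each candidate k
cv_scores = {}
for k in range(1, 31):
    knn = KNeighborsClassifier(n_neighbors=k)
    scores = cross_val_score(knn, X, y, cv=5)
    cv_scores[k] = scores.mean()

# Pick the k with the highest mean accuracy
best_k = max(cv_scores, key=cv_scores.get)
print(f"Best k: {best_k}, mean CV accuracy: {cv_scores[best_k]:.3f}")
```

This averages accuracy over five folds, so the chosen k is less sensitive to how one particular split happened to fall.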