import numpy as np import matplotlib.pyplot as plt from matplotlib.colors import ListedColormap from sklearn.neighbors import KNeighborsClassifier from sklearn.datasets import load_iris iris=load_iris() X=iris.data[:,:2] Y=iris.target print(iris.feature_names) cmap_light=ListedColormap(['#FFAAAA','#AAFFAA','#AAAAFF']) cmap_bold=ListedColormap(['#FF0000','#00FF00','#0000FF']) clf=KNeighborsClassifier(n_neighbors=10,weights='uniform') clf.fit(X,Y) #画出决策边界 x_min,x_max=X[:,0].min()-1,X[:,0].max()+1 y_min,y_max=X[:,1].min()-1,X[:,1].max()+1 xx,yy=np.meshgrid(np.arange(x_min,x_max,0.02), np.arange(y_min,y_max,0.02)) Z=clf.predict(np.c_[xx.ravel(),yy.ravel()]) reshape(xx.shape) plt.figure() plt.pcolormesh(xx,yy,Z,cmap=cmap_light) #绘制预测结果图 plt.scatter(X[:,0],X[:,1],c=Y,cmap=cmap_bold) plt.xlim(xx,min(),xx.max()) plt.ylim(yy.min(),yy.max()) plt.title('3_Class(k=10,weights=uniform)') plt.show()
时间: 2024-04-25 18:27:42 浏览: 98
这段代码的作用是使用K近邻算法对鸢尾花数据集进行分类,并绘制出分类的决策边界图。具体来说,代码首先加载鸢尾花数据集并提取其中的前两个特征作为模型输入,然后使用KNeighborsClassifier类创建一个K近邻分类器,并使用fit方法拟合数据集。接着,代码使用meshgrid函数生成一组网格点,并将其作为输入数据进行分类预测,最后将分类结果绘制成彩色的决策边界图。同时,代码还会将数据集中的样本点用散点图表示出来。
阅读全文