from sklearn.datasets import load_iris from sklearn. model_selection import train_test_split from sklearn.metrics import classification_report from sklearn. neighbors import KNeighborsClassifier from sklearn. metrics import roc_curve, auc import matplotlib.pyplot as plt from sklearn. metrics import confusion_matrix import seaborn as sns import scikitplot as skplt #加载数据集 iris = load_iris() data = iris['data'] label = iris['target'] #数据集的划分 x_train,x_test,y_train,y_test = train_test_split(data,label,test_size=0.3) print(x_train) #模型构建 model = KNeighborsClassifier(n_neighbors=5) model.fit(x_train,y_train) #模型评估 #(1)精确率,召回率,F1分数,准确率(宏平均和微平均) predict = model. predict(x_test) result = classification_report(y_test,predict) print(result) # (2) 混淆矩阵 confusion_matrix = confusion_matrix(y_test, predict) print('混淆矩阵:', confusion_matrix) sns.set(font_scale=1) sns.heatmap(confusion_matrix, annot=True, annot_kws={"size", 16}, cmap=plt.cm.Blues) plt.title('Confusion Matrix') plt.ylabel('True label' ) plt.xlabel('Predicted label') plt.savefig('Confusion matrix. pdf') plt.show() #(3)ROC曲线 Y_pred_prob = model. predict_proba(x_test) plt.figure(figsize= (7,7)) ax= plt. subplot() skplt.metrics.plot_roc_curve(y_test,Y_pred_prob,ax= ax) ax.set_xlabel('False Positive Rate', fontsize = 20) ax.set_ylabel('True Positive Rate ',fontsize = 20) ax.set_title('ROC Areas ',fontsize = 20) plt.xlim((0, 1)) plt.ylim((0, 1)) plt.xticks(fontsize = 18) plt.yticks(fontsize = 18) plt.legend(fontsize =18) plt.savefig(' ROC.pdf') plt.show( ) #(4)P_R曲线 from sklearn.metrics import precision_recall_curve precision, recall, _ =precision_recall_curve(y_test) plt.fill_between(recall, precision,color='b') plt.xlabel('Recall') plt.ylabel('Precision') plt.ylim([0.0, 1.0]) plt.xlim([0.0, 1.0]) plt.plot(recall, precision) plt.title("Precision-Recall") plt.show()
时间: 2023-06-06 19:04:20 浏览: 168
SVM.zip_SVM Iris分类_UCI SVM_iris_svm预测_uci svm
这是一段 Python 代码,主要包括从Scikit-learn库中导入了一些数据集、分类器、评估指标和可视化工具的模块。其中load_iris用于加载鸢尾花数据集,train_test_split用于将数据集分为训练集和测试集,classification_report用于评估分类器的性能,KNeighborsClassifier是K近邻分类器,roc_curve和auc用于评估二元分类器的性能,matplotlib.pyplot和seaborn是可视化工具,confusion_matrix和scikitplot是评估和可视化混淆矩阵的工具。
阅读全文