xgboost利用鸢尾花数据集训练并绘制roc曲线
时间: 2023-08-07 21:01:03 浏览: 248
这是鸢尾花数据集
5星 · 资源好评率100%
首先,鸢尾花数据集是机器学习领域一个经典的数据集,包含了三个不同种类的鸢尾花的样本,每个样本包含了四个特征:萼片长度、萼片宽度、花瓣长度和花瓣宽度。xgboost是一种高效的机器学习算法,擅长处理分类和回归问题。
要使用xgboost对鸢尾花数据集进行训练并绘制ROC曲线,我们需要执行以下步骤:
首先,导入鸢尾花数据集和xgboost库。
import xgboost as xgb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
然后,加载鸢尾花数据集。
iris = load_iris()
X = iris.data
y = iris.target
接下来,将数据集分为训练集和测试集。
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
然后,使用xgboost模型训练。
model = xgb.XGBClassifier()
model.fit(X_train, y_train)
接着,进行预测并计算ROC曲线的关键指标-真阳性率和假阳性率。
y_pred = model.predict_proba(X_test)[:,1]
roc_auc = roc_auc_score(y_test, y_pred)
最后,绘制ROC曲线。
fpr, tpr, thresholds = roc_curve(y_test, y_pred)
plt.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], 'r--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.legend(loc="lower right")
plt.show()
以上就是使用xgboost对鸢尾花数据集进行训练并绘制ROC曲线的步骤。通过观察ROC曲线的形状和计算得到的曲线下面积(AUC),我们可以评估xgboost模型的性能和质量。
阅读全文