鸢尾花数据是一个用于机器学习的经典数据集,用DT算法进行分类,并将其分类效果,即AUC曲线做出来,将代码写出来
时间: 2024-11-18 07:18:36 浏览: 13
用于进行分类算法的8个数据集
鸢尾花数据集(Iris dataset)通常用于监督学习的入门教程,特别是决策树(Decision Tree, DT)算法的演示。这个数据集包含三种类型的鸢尾花(Setosa、Versicolour 和 Virginica),每种都有四个特征:萼片长度、萼片宽度、花瓣长度和花瓣宽度。
要在Python中使用`scikit-learn`库来处理鸢尾花数据集,包括训练决策树模型并生成AUC(Area Under Curve)曲线,你可以参考以下步骤:
```python
# 导入所需的库
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import roc_auc_score, roc_curve
import matplotlib.pyplot as plt
# 加载鸢尾花数据集
iris = datasets.load_iris()
X = iris.data
y = iris.target
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 创建决策树分类器
clf = DecisionTreeClassifier()
# 训练模型
clf.fit(X_train, y_train)
# 预测并计算概率
y_scores = clf.predict_proba(X_test)[:, 1] # 取第1列,假设目标变量是二分类问题
# 计算AUC分数
auc = roc_auc_score(y_test, y_scores)
print(f"AUC score: {auc}")
# 画出ROC曲线
fpr, tpr, _ = roc_curve(y_test, y_scores)
plt.plot(fpr, tpr, label=f"ROC curve (area = {auc:.2f})")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver Operating Characteristic (ROC)")
plt.legend(loc="lower right")
plt.show()
```
在这个例子中,我们首先加载数据,然后划分训练集和测试集,创建并训练决策树模型。接着,预测测试集的结果并得到二分类的概率,计算AUC值。最后,通过`roc_curve`函数绘制ROC曲线并展示结果。
阅读全文