利用adults数据集写一篇决策树代码,并画出ROC图
时间: 2024-03-19 10:42:08 浏览: 82
好的,下面是利用adult数据集进行决策树分类,并画出ROC图的Python代码:
```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
# 读取数据集
data = pd.read_csv('adult.csv')
# 数据预处理
data = data.dropna() # 删除缺失值
data['sex'] = data['sex'].apply(lambda x: 1 if x == 'Male' else 0) # 将性别转换为数值变量
data = pd.get_dummies(data, columns=['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race', 'native-country', 'income'], drop_first=True) # 将其他分类变量转换为数值变量
# 划分训练集和测试集
X = data.drop('income_>50K', axis=1)
y = data['income_>50K']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 建立决策树模型
clf = DecisionTreeClassifier(max_depth=5, random_state=42)
clf.fit(X_train, y_train)
# 对测试集进行预测
y_pred = clf.predict_proba(X_test)[:, 1]
# 绘制ROC曲线
fpr, tpr, thresholds = roc_curve(y_test, y_pred)
roc_auc = auc(fpr, tpr)
plt.plot(fpr, tpr, label='AUC = %0.2f' % roc_auc)
plt.plot([0, 1], [0, 1], 'r--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend(loc='lower right')
plt.show()
```
在上述代码中,我们首先对数据集进行了预处理,包括删除缺失值、将性别转换为数值变量、将其他分类变量转换为数值变量。然后,我们将数据集划分为训练集和测试集,并建立了一个最大深度为5的决策树模型。最后,我们对测试集进行预测,并绘制了ROC曲线。
运行上述代码,即可得到绘制的ROC曲线图。
阅读全文