Using a decision tree on a dataset named data.csv with 10 feature columns and 2 classes, implement information gain in Python and show the complete code
Posted: 2024-05-01 15:16:39 Views: 78
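The complete Python code below trains a decision tree with scikit-learn. Setting `criterion='entropy'` makes the tree choose its splits by information gain: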
```python
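import graphviz
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Load the dataset
data = pd.read_csv('data.csv')

# Use the 'label' column as the classification target and every other column as a feature
# (the column name 'label' is an assumption about data.csv)
X = data.drop('label', axis=1)
y = data['label']

# Split into training and test sets (80% / 20%)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

# Build the decision tree; criterion='entropy' selects splits by information gain
clf = DecisionTreeClassifier(criterion='entropy')
clf.fit(X_train, y_train)

# Predict on the test set
y_pred = clf.predict(X_test)

# Compute the accuracy
accuracy = np.mean(y_pred == y_test)
print('Accuracy: {:.2f}%'.format(accuracy * 100))

# Export the fitted tree in DOT format
dot_data = export_graphviz(
    clf,
    out_file=None,
    feature_names=list(X.columns),
    # Take class names from the fitted model instead of hard-coding '0' and '1'
    class_names=[str(c) for c in clf.classes_],
    filled=True,
    rounded=True,
    special_characters=True,
)

# Render the tree to a file named decision_tree.pdf (requires the Graphviz binaries)
graph = graphviz.Source(dot_data)
graph.render("decision_tree")
graph  # in a Jupyter notebook, this displays the tree inline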
```
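In this code, the `pandas` library reads and prepares the dataset. The `sklearn.tree` module provides the `DecisionTreeClassifier` class used to build the decision tree model; with `criterion='entropy'`, each split is chosen to maximize information gain, that is, the reduction in label entropy. The `train_test_split` function splits the data into training and test sets. The `numpy` library computes the accuracy. The `export_graphviz` function outputs the tree as DOT source, and the `graphviz` library renders and displays it.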
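The scikit-learn code uses information gain internally but never prints it. As a minimal sketch of the quantity itself, the following computes the information gain of each feature by hand. The helper names `entropy` and `information_gain` and the toy DataFrame are illustrative assumptions, not part of the original answer, and the sketch treats every distinct feature value as its own branch, whereas scikit-learn trees use binary threshold splits.

```python
import numpy as np
import pandas as pd


def entropy(labels):
    """Shannon entropy, in bits, of a sequence of class labels."""
    _, counts = np.unique(labels, return_counts=True)
    probs = counts / counts.sum()
    return float(-(probs * np.log2(probs)).sum())


def information_gain(feature, labels):
    """Entropy of `labels` minus the weighted entropy after grouping by `feature`.

    Hypothetical helper: each distinct feature value forms one branch.
    """
    feature = np.asarray(feature)
    labels = np.asarray(labels)
    weighted = 0.0
    for value in np.unique(feature):
        mask = feature == value
        # mask.mean() is the fraction of rows that fall into this branch
        weighted += mask.mean() * entropy(labels[mask])
    return entropy(labels) - weighted


# Toy stand-in for data.csv (assumed layout: feature columns plus a 'label' column)
toy = pd.DataFrame({
    "f1": [0, 0, 1, 1],
    "f2": [0, 1, 0, 1],
    "label": [0, 0, 1, 1],
})

gains = {col: information_gain(toy[col], toy["label"]) for col in ["f1", "f2"]}
print(gains)
```

Here `f1` separates the two classes perfectly, so its gain equals the full label entropy of 1 bit, while `f2` tells nothing about the label and has a gain of 0. To apply the same idea to the real dataset, replace `toy` with `pd.read_csv('data.csv')` and loop over its feature columns; continuous features would first need to be binned or split on a threshold.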