通过构建决策树(采用Gini指标)对随机生成(通过np.random.randint方法)的数字进行分类,自变量X为100x4的矩阵,随机生成的数字大于10,因变量Y为大于2的100x1矩阵。树的最大深度限制为3层,训练完成之后将树可视化显示。通过代码演示
时间: 2024-10-10 07:12:08 浏览: 38
比对Gini和Twoing-CART决策树
在Python中,我们可以使用scikit-learn库中的DecisionTreeClassifier来构建决策树,Gini指数作为分裂依据。首先,我们需要导入所需的库,然后生成数据,接着构建模型、训练并可视化决策树。以下是示例代码:
```python
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
# 生成随机数据
np.random.seed(0) # 设置随机种子以便于复现结果
X, y = make_classification(n_samples=100, n_features=4, n_informative=2,
n_classes=2, class_sep=3, random_state=42)
# 只保留大于10的数值
X = X[y > 10]
y = y[y > 2]
# 构建决策树模型,最大深度设为3
clf = DecisionTreeClassifier(criterion='gini', max_depth=3)
# 训练模型
clf.fit(X, y)
# 将数据转换为DataFrame方便绘制
df = pd.DataFrame(data=np.c_[X.ravel(), y], columns=[f'Feature_{i}' for i in range(1, 5)] + ['Label'])
# 定义画图函数
def plot_tree(clf, df, feature_names):
dot_data = StringIO()
export_graphviz(clf, out_file=dot_data,
filled=True, rounded=True,
special_characters=True, feature_names=feature_names,
class_names=['Class_1', 'Class_2'])
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
plt.figure(figsize=(12, 8))
_ = nx.nx_pydot.from_pydot(graph.create_dot()).draw()
# 可视化决策树
plot_tree(clf, df, [f'Feature_{i}' for i in range(1, 5)])
plt.show()
```
在这个例子中,我们首先创建了一个简单的二分类问题,然后训练了决策树模型,并利用`export_graphviz`函数将其转换为DOT语言进行可视化。最后通过`pydotplus`库将DOT文件展示出来。
阅读全文