cart算法代码python
时间: 2023-11-29 18:46:47 浏览: 144
以下是使用Python实现CART算法的代码示例:
```python
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
# 加载数据集
iris = load_iris()
X = iris.data[:, [2, 3]]
y = iris.target
# 训练模型
tree = DecisionTreeClassifier(criterion='gini', max_depth=4, random_state=1)
tree.fit(X, y)
# 可视化决策树
from pydotplus import graph_from_dot_data
from sklearn.tree import export_graphviz
dot_data = export_graphviz(tree, filled=True, rounded=True, class_names=['Setosa', 'Versicolor', 'Virginica'],
feature_names=['petal length', 'petal width'], out_file=None)
graph = graph_from_dot_data(dot_data)
graph.write_png('tree.png')
```
上述代码使用了Python的scikit-learn库中的DecisionTreeClassifier类来实现CART算法。首先,我们加载了鸢尾花数据集,并将其分为特征矩阵X和目标向量y。然后,我们使用DecisionTreeClassifier类来训练模型,并指定了基尼指数作为分裂标准,最大深度为4,随机种子为1。最后,我们使用export_graphviz函数将决策树导出为Graphviz格式,并使用graph_from_dot_data函数将其转换为PNG格式。
阅读全文