sklearn决策树散点图代码
时间: 2024-10-11 09:01:18 浏览: 30
在Scikit-learn中,创建决策树的散点图(也称为决策边界)需要先训练一个决策树模型,然后利用`plot_tree`函数来可视化。这里是一个简单的例子,假设我们有鸢尾花数据集(Iris dataset)并使用`DecisionTreeClassifier`训练了一个模型:
```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
# 加载数据
iris = load_iris()
X = iris.data
y = iris.target
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 训练决策树
clf = DecisionTreeClassifier(random_state=42)
clf.fit(X_train, y_train)
# 可视化决策树
plt.figure(figsize=(10, 6))
tree.plot_tree(clf, filled=True, class_names=iris.target_names, feature_names=iris.feature_names)
plt.title("Decison Tree for Iris Dataset")
plt.show()
```
这段代码首先加载了鸢尾花数据,接着划分了训练集和测试集,然后用训练数据训练了一个决策树模型。最后通过`plot_tree`方法绘制了决策树,并显示出来。
如果你想要看到每个样本点在二维空间(比如前两个主成分)中的位置及其所属的决策路径,可以使用`export_graphviz`将决策树导出为图形文件,然后用Graphviz软件打开,结合其他工具(如`pydot`)读取并渲染成图像。
注意:实际操作时可能需要安装额外的依赖,如`graphviz`和`pydotplus`。
阅读全文