鸢尾花数据集画出决策树
时间: 2024-09-19 21:00:46 浏览: 33
鸢尾花数据集(Iris dataset)是一个经典的机器学习数据集,常用于初学者学习分类算法。它包含了三种不同种类的鸢尾花(Setosa、Versicolour和Virginica)的测量值,如花瓣长度、花瓣宽度等特征。绘制决策树通常涉及使用像Scikit-Learn这样的库中的`DecisionTreeClassifier`模型,并通过`plot_tree()`函数可视化。
首先,你需要导入必要的库:
```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt
```
然后加载数据并预处理:
```python
iris = load_iris()
df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(df, y, test_size=0.2, random_state=42)
```
接下来训练决策树模型:
```python
clf = DecisionTreeClassifier(random_state=42)
clf.fit(X_train, y_train)
```
最后,绘制决策树:
```python
plt.figure(figsize=(15, 10))
plot_tree(clf, filled=True, class_names=iris.target_names, feature_names=iris.feature_names)
plt.title("Iris Dataset Decision Tree")
plt.show()
```
这将显示一个基于鸢尾花特征和标签划分的决策树图。