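Split the Iris dataset into training and test sets, compare unpruned, pre-pruned, and post-pruned decision trees, and plot the resulting trees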
Date: 2024-02-25 18:45:54
Sure, here is how to do it step by step.
First, import the required libraries and load the dataset:
```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from graphviz import Source
iris = load_iris()
X = pd.DataFrame(iris.data, columns=iris.feature_names)
y = iris.target
```
Next, split the data into a training set (80%) and a test set (20%):
```python
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
```
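The split above is purely random, so class proportions in the 30-sample test set can drift from the original 50/50/50 balance. As an optional variation (not part of the original answer), the sketch below passes `stratify=y` to `train_test_split` so that both subsets keep the same class ratio. The variable names `X_tr`, `X_te`, `y_tr`, and `y_te` are illustrative only.

```python
from collections import Counter

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()
X, y = iris.data, iris.target

# stratify=y preserves the class proportions in both subsets
X_tr, X_te, y_tr, y_te = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Count how many samples of each class ended up in each subset
train_counts = Counter(y_tr.tolist())
test_counts = Counter(y_te.tolist())
print("train:", dict(train_counts))
print("test: ", dict(test_counts))
```

On a small dataset like Iris, a stratified split tends to make accuracy comparisons between the three trees a little less dependent on the luck of the draw.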
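Then train an unpruned decision tree, predict on the test set, and compute its accuracy:
```python
# Train the unpruned decision tree (no depth or leaf limits)
tree_unpruned = DecisionTreeClassifier(random_state=42)
tree_unpruned.fit(X_train, y_train)
# Predict on the test set
y_pred_unpruned = tree_unpruned.predict(X_test)
# Compute the accuracy
accuracy_unpruned = accuracy_score(y_test, y_pred_unpruned)
print(f"Unpruned decision tree accuracy: {accuracy_unpruned}")
```
The output is:
```
Unpruned decision tree accuracy: 1.0
```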
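Test accuracy alone does not show how large the unpruned tree has grown. The following sketch, which rebuilds the same split so that it runs on its own, reports the depth, the number of leaves, and the training accuracy using the estimator's `get_depth()`, `get_n_leaves()`, and `score()` methods. The variable names are illustrative.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42
)

# A tree with default settings grows until every leaf is pure
tree = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

train_acc = tree.score(X_train, y_train)  # accuracy on the data it was fit on
test_acc = tree.score(X_test, y_test)     # accuracy on held-out data
depth = tree.get_depth()
n_leaves = tree.get_n_leaves()

print(f"depth={depth}, leaves={n_leaves}")
print(f"train accuracy={train_acc:.3f}, test accuracy={test_acc:.3f}")
```

A gap between training and test accuracy is the usual sign of overfitting that pruning is meant to reduce; on Iris that gap is often small because the classes are easy to separate.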
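Next, train a pre-pruned decision tree by limiting its depth with `max_depth=2`, and compute its accuracy:
```python
# Train the pre-pruned decision tree (growth stops at depth 2)
tree_pruned = DecisionTreeClassifier(max_depth=2, random_state=42)
tree_pruned.fit(X_train, y_train)
# Predict on the test set
y_pred_pruned = tree_pruned.predict(X_test)
# Compute the accuracy
accuracy_pruned = accuracy_score(y_test, y_pred_pruned)
print(f"Pre-pruned decision tree accuracy: {accuracy_pruned}")
```
The output is:
```
Pre-pruned decision tree accuracy: 1.0
```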
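`max_depth` is only one of several pre-pruning controls; `DecisionTreeClassifier` also accepts `min_samples_leaf`, `min_samples_split`, and `max_leaf_nodes`. Rather than fixing `max_depth=2` by hand, one option is to pick the limits by cross-validation on the training set. The sketch below does this with `GridSearchCV`; the grid values are arbitrary choices for illustration, not recommendations from the original answer.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42
)

# Hypothetical search grid over two pre-pruning parameters
param_grid = {
    "max_depth": [2, 3, 4, None],
    "min_samples_leaf": [1, 2, 5],
}

# 5-fold cross-validation on the training set only; the test set stays untouched
search = GridSearchCV(DecisionTreeClassifier(random_state=42), param_grid, cv=5)
search.fit(X_train, y_train)

best_params = search.best_params_
test_acc = search.score(X_test, y_test)  # score of the refit best estimator
print("best parameters:", best_params)
print(f"test accuracy: {test_acc:.3f}")
```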
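Next, apply post-pruning (cost-complexity pruning) and compute the accuracy of the pruned tree:
```python
# Train the unpruned decision tree
tree_unpruned = DecisionTreeClassifier(random_state=42)
tree_unpruned.fit(X_train, y_train)
# Post-pruning: compute the candidate ccp_alpha values along the pruning path
path = tree_unpruned.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impurities
# Fit one tree per candidate alpha (a larger alpha prunes more aggressively)
clfs = []
for ccp_alpha in ccp_alphas:
    clf = DecisionTreeClassifier(random_state=42, ccp_alpha=ccp_alpha)
    clf.fit(X_train, y_train)
    clfs.append(clf)
# Compute the test-set accuracy of each tree
acc_pruned = []
for clf in clfs:
    y_pred = clf.predict(X_test)
    acc = accuracy_score(y_test, y_pred)
    acc_pruned.append(acc)
# Pick the best model. Note: selecting on the test set makes the reported
# accuracy optimistic; a validation set or cross-validation is safer.
best_idx = acc_pruned.index(max(acc_pruned))
# Use a separate name so the pre-pruned tree_pruned is not overwritten
tree_post_pruned = clfs[best_idx]
# Compute the accuracy
accuracy_post_pruned = accuracy_score(y_test, tree_post_pruned.predict(X_test))
print(f"Post-pruned decision tree accuracy: {accuracy_post_pruned}")
```
The output is:
```
Post-pruned decision tree accuracy: 1.0
```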
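The loop above chooses the best `ccp_alpha` by looking at test-set accuracy, which leaks test information into model selection. A possible alternative, sketched here under the assumption that 5-fold cross-validation on the training set is acceptable, scores each candidate alpha with `cross_val_score` and only touches the test set once at the end. Clipping the alphas at zero is a defensive step against tiny negative values from floating-point error.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42
)

# Candidate alphas from the pruning path of the fully grown tree
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(
    X_train, y_train
)
# Drop the last alpha (it prunes the tree down to the root) and clip at zero
alphas = np.unique(np.clip(path.ccp_alphas[:-1], 0.0, None))

# Mean 5-fold cross-validated accuracy for each alpha, training data only
cv_means = [
    cross_val_score(
        DecisionTreeClassifier(random_state=42, ccp_alpha=a), X_train, y_train, cv=5
    ).mean()
    for a in alphas
]

# Refit on the full training set with the best alpha, then evaluate once
best_alpha = float(alphas[int(np.argmax(cv_means))])
best_tree = DecisionTreeClassifier(random_state=42, ccp_alpha=best_alpha)
best_tree.fit(X_train, y_train)
test_acc = best_tree.score(X_test, y_test)
print(f"best ccp_alpha={best_alpha:.4f}, test accuracy={test_acc:.3f}")
```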
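Finally, visualize the decision trees with the graphviz library:
```python
# Visualize the unpruned decision tree
graph_unpruned = Source(export_graphviz(tree_unpruned, out_file=None,
                                        feature_names=iris.feature_names,
                                        class_names=list(iris.target_names),
                                        filled=True))
# In a Jupyter notebook only the last expression of a cell is rendered,
# so place each graph in its own cell to see both
graph_unpruned
# Visualize the pre-pruned decision tree
graph_pruned = Source(export_graphviz(tree_pruned, out_file=None,
                                      feature_names=iris.feature_names,
                                      class_names=list(iris.target_names),
                                      filled=True))
graph_pruned
```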
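The unpruned decision tree looks like this:
![Unpruned decision tree](https://img-blog.csdnimg.cn/20211207101605503.png)
The pre-pruned decision tree looks like this:
![Pre-pruned decision tree](https://img-blog.csdnimg.cn/20211207101620402.png)
Because post-pruning produces a whole series of candidate trees, one per `ccp_alpha`, the post-pruned tree is not visualized here.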
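If a picture of one post-pruned tree is still wanted, a single tree can be drawn once a `ccp_alpha` has been chosen. The sketch below uses `sklearn.tree.plot_tree` with matplotlib, which avoids the Graphviz dependency. The value `ccp_alpha=0.01` and the output file name are placeholders for illustration, not values from the original answer.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; omit this line inside a notebook
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=42
)

# Hypothetical alpha; in practice use the value selected during post-pruning
tree = DecisionTreeClassifier(random_state=42, ccp_alpha=0.01)
tree.fit(X_train, y_train)

# Draw the tree onto a matplotlib axis and save it as a PNG file
fig, ax = plt.subplots(figsize=(8, 6))
artists = plot_tree(
    tree,
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),
    filled=True,
    ax=ax,
)
fig.savefig("tree_post_pruned.png", dpi=150)
```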