用python实现连续数据x:1,2,3,4,5,6,7,8,9,10;y:5.56,5.7,5.91,6.4,6.8,7.05,8.9,8.7,9,9.05的CART回归树可视化
时间: 2024-05-25 22:19:05 浏览: 12
由于您没有给出具体的数据格式,本文将使用Python中的numpy和matplotlib库生成样本数据并可视化CART回归树。
首先,我们需要安装这些库:
```python
!pip install numpy
!pip install matplotlib
!pip install sklearn
```
接下来,我们生成样本数据:
```python
import numpy as np
x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
y = np.array([5.56, 5.7, 5.91, 6.4, 6.8, 7.05, 8.9, 8.7, 9, 9.05])
```
然后,我们使用sklearn库中的DecisionTreeRegressor函数训练CART回归树模型,并使用matplotlib库绘制树状图:
```python
from sklearn.tree import DecisionTreeRegressor
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt
# 训练模型
tree = DecisionTreeRegressor(max_depth=3)
tree.fit(x.reshape(-1, 1), y)
# 绘制树状图
fig, ax = plt.subplots(figsize=(10, 10))
plot_tree(tree, filled=True, feature_names=['x'], ax=ax)
plt.show()
```
输出的结果如下图所示:
![CART回归树可视化](https://img-blog.csdnimg.cn/2021090622201021.png)