python CART回归树
时间: 2024-10-13 17:00:50 浏览: 12
Python中的CART回归树,全称为Classification And Regression Trees,是一种用于预测数值型变量(回归问题)或分类问题(分类问题)的决策树算法。它由Leo Breiman等人开发,并通过sklearn库在Python中得到了广泛的应用。
CART回归树的工作原理是将数据集划分为一系列越来越小、越来越纯净的子集(即划分节点),每个子集对应于一个特征值。对于连续特征,划分通常是基于某个阈值;对于离散特征,则直接比较其取值。这个过程一直持续到满足停止条件(如达到最大深度,样本数过少等),然后在叶子节点处进行平均(回归问题)或多数投票(分类问题)来进行预测。
在Python中,你可以使用`sklearn.tree`模块中的`DecisionTreeRegressor`类来构建CART回归树模型,而`plot_tree()`函数则可以帮助可视化树结构。下面是一个简单的例子:
```python
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
import matplotlib.pyplot as plt
# 加载波士顿房价数据
boston = load_boston()
X, y = boston.data, boston.target
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
# 创建并训练模型
model = DecisionTreeRegressor(random_state=42)
model.fit(X_train, y_train)
# 可视化回归树
plt.figure(figsize=(25, 10))
_ = plot_tree(model, filled=True)
plt.show()
# 预测
predictions = model.predict(X_test)
```
阅读全文