python cart树_CART回归树python实现
时间: 2023-10-11 20:10:54 浏览: 110
下面是一个简单的 Python 实现 CART 回归树的例子:
```python
import numpy as np
class CARTRegressionTree:
def __init__(self, max_depth=5, min_samples_split=2, min_samples_leaf=1):
self.max_depth = max_depth
self.min_samples_split = min_samples_split
self.min_samples_leaf = min_samples_leaf
def fit(self, X, y):
self.n_features_ = X.shape[1]
self.tree_ = self._build_tree(X, y, depth=0)
def predict(self, X):
return np.array([self._predict(inputs) for inputs in X])
def _build_tree(self, X, y, depth):
n_samples, n_features = X.shape
if depth >= self.max_depth or n_samples < self.min_samples_split:
return self._make_leaf_node(y)
best_feature, best_threshold = self._find_split(X, y, n_samples, n_features)
if best_feature is None:
return self._make_leaf_node(y)
left_indices = X[:, best_feature] < best_threshold
right_indices = ~left_indices
left_tree = self._build_tree(X[left_indices], y[left_indices], depth+1)
right_tree = self._build_tree(X[right_indices], y[right_indices], depth+1)
return self._make_decision_node(best_feature, best_threshold, left_tree, right_tree)
def _find_split(self, X, y, n_samples, n_features):
best_feature, best_threshold, best_mse = None, None, np.inf
for feature in range(n_features):
thresholds = np.unique(X[:, feature])
for threshold in thresholds:
left_indices = X[:, feature] < threshold
right_indices = ~left_indices
if sum(left_indices) >= self.min_samples_leaf and sum(right_indices) >= self.min_samples_leaf:
left_mse = self._mean_squared_error(y[left_indices])
right_mse = self._mean_squared_error(y[right_indices])
mse = left_mse + right_mse
if mse < best_mse:
best_feature, best_threshold, best_mse = feature, threshold, mse
return best_feature, best_threshold
def _mean_squared_error(self, y):
return np.mean((y - np.mean(y))**2)
def _make_decision_node(self, feature, threshold, left_tree, right_tree):
return {'feature': feature,
'threshold': threshold,
'left_tree': left_tree,
'right_tree': right_tree}
def _make_leaf_node(self, y):
return {'leaf': np.mean(y)}
def _predict(self, inputs):
node = self.tree_
while isinstance(node, dict):
if inputs[node['feature']] < node['threshold']:
node = node['left_tree']
else:
node = node['right_tree']
return node['leaf']
```
在上面的代码中,`CARTRegressionTree` 类实现了 CART 回归树的训练和预测方法。其中,`max_depth`、`min_samples_split` 和 `min_samples_leaf` 是可调参数,分别控制树的最大深度、节点分裂所需的最小样本数和叶节点所需的最小样本数。`fit` 方法用于训练模型,`predict` 方法用于预测新样本。
在 `_build_tree` 方法中,我们递归地构建决策树。如果当前节点样本数少于 `min_samples_split` 或者树的深度达到了 `max_depth`,则返回当前节点的平均值作为叶节点。否则,我们尝试在所有特征和阈值的组合中找到一个最优的分裂点,使得分裂后的左右子树的均方误差之和最小。如果无法找到合适的分裂点,则返回当前节点的平均值作为叶节点。
在 `_find_split` 方法中,我们遍历所有特征和阈值的组合,计算分裂后的左右子树的均方误差之和。如果左右子树的样本数都大于等于 `min_samples_leaf`,则更新最优的分裂点。
在 `_make_decision_node` 和 `_make_leaf_node` 方法中,我们分别创建决策节点和叶节点。
在 `_predict` 方法中,我们遍历决策树,根据当前节点的特征和阈值判断向左还是向右走,直到到达叶节点,返回叶节点的平均值作为预测值。
阅读全文