python实现cart回归树
时间: 2023-12-04 16:04:40 浏览: 178
CART(Classification and Regression Tree)回归树是一种常用的决策树算法,它在每个节点进行二分,既可以用于分类问题又可以用于回归问题。
下面是一个简单的Python实现CART回归树的示例代码:
```python
import numpy as np
class CARTRegressionTree:
    """CART regression tree.

    Recursively binary-splits the feature space, choosing at each node the
    (feature, threshold) pair that maximizes the reduction in sum of squared
    errors (SSE). Leaves predict the mean of the targets that reach them.

    The fitted tree is stored as a nested dict: internal nodes are
    ``{"feature", "threshold", "left", "right"}``, leaves are
    ``{"leaf_value"}``.
    """

    def __init__(self, max_depth=5, min_samples_split=2):
        # max_depth: maximum depth of the tree (the root is depth 1).
        # min_samples_split: minimum number of samples a node needs to be split.
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.tree = None  # populated by fit()

    def fit(self, X, y):
        """Build the tree from features X (n_samples, n_features) and targets y.

        Inputs are coerced to float ndarrays, so plain Python lists are accepted.
        """
        X = np.asarray(X, dtype=float)
        y = np.asarray(y, dtype=float)
        self.tree = self.build_tree(X, y, depth=1)

    def build_tree(self, X, y, depth):
        """Recursively grow the tree; return the nested-dict node for (X, y)."""
        n_samples = X.shape[0]
        if n_samples >= self.min_samples_split and depth <= self.max_depth:
            best_feature, best_threshold = self.get_best_split(X, y)
            # get_best_split returns (None, None) when no threshold produces
            # two non-empty children (e.g. all feature values identical).
            if best_feature is not None and best_threshold is not None:
                left_X, left_y, right_X, right_y = self.split(X, y, best_feature, best_threshold)
                return {
                    "feature": best_feature,
                    "threshold": best_threshold,
                    "left": self.build_tree(left_X, left_y, depth + 1),
                    "right": self.build_tree(right_X, right_y, depth + 1),
                }
        # Stopping condition reached (or no valid split): emit a leaf.
        return {"leaf_value": self.get_leaf_value(y)}

    def get_best_split(self, X, y):
        """Exhaustively search every (feature, unique value) pair.

        Returns the (feature, threshold) with the largest SSE reduction, or
        (None, None) if no candidate yields two non-empty children.
        """
        best_gain = -float("inf")
        best_feature = best_threshold = None
        n_features = X.shape[1]
        for feature in range(n_features):
            column = X[:, feature]  # hoist the column view out of the inner loop
            for threshold in np.unique(column):
                mask = column <= threshold  # compute the partition mask once
                left_y = y[mask]
                right_y = y[~mask]
                if len(left_y) > 0 and len(right_y) > 0:
                    gain = self.get_gain(y, left_y, right_y)
                    if gain > best_gain:
                        best_gain = gain
                        best_feature = feature
                        best_threshold = threshold
        return best_feature, best_threshold

    def get_gain(self, parent, left, right):
        """Variance-reduction gain: SSE(parent) - SSE(left) - SSE(right).

        Since left and right partition parent, the size-weighted mean of the
        children equals the parent mean, so np.mean(parent) is used directly.
        """
        def sse(values):
            # Sum of squared deviations from the group mean.
            return np.sum((values - np.mean(values)) ** 2)

        return sse(parent) - sse(left) - sse(right)

    def split(self, X, y, feature, threshold):
        """Partition (X, y) by ``X[:, feature] <= threshold`` (mask built once)."""
        mask = X[:, feature] <= threshold
        return X[mask], y[mask], X[~mask], y[~mask]

    def get_leaf_value(self, y):
        """Leaf prediction: mean of the targets that reached the leaf."""
        return np.mean(y)

    def predict(self, X):
        """Return an array with one prediction per row of X."""
        X = np.asarray(X, dtype=float)
        return np.array([self.predict_one(x, self.tree) for x in X])

    def predict_one(self, x, tree):
        """Walk one sample down the tree and return the leaf value it hits."""
        if "leaf_value" in tree:
            return tree["leaf_value"]
        branch = "left" if x[tree["feature"]] <= tree["threshold"] else "right"
        return self.predict_one(x, tree[branch])
```
其中,`max_depth`控制树的最大深度,`min_samples_split`是节点继续分裂所需的最小样本数;`fit`方法用于训练模型,`predict`方法用于预测新样本。每个内部节点以使平方误差(SSE)下降最多的特征与阈值进行二分,叶节点的预测值为落入该叶的样本目标均值。
阅读全文