CART 回归树代码
时间: 2023-08-06 09:14:10 浏览: 178
下面是一个简单的 CART 回归树的 Python 代码实现:
```python
import numpy as np
class CartRegressionTree:
    """Minimal CART regression tree built on NumPy.

    The fitted tree is stored in ``self.tree`` as nested tuples
    ``(feature_index, threshold, left_subtree, right_subtree)``; a leaf is a
    plain float holding the mean target of the samples that reached it.
    Samples with ``x[feature_index] <= threshold`` go to the left subtree,
    all others to the right.

    Args:
        max_depth: Maximum number of split levels. ``None`` means unlimited.
        min_samples_split: Minimum number of samples a node must hold to be
            considered for splitting.
        min_samples_leaf: Minimum number of samples each child must receive
            for a candidate split to be accepted (values below 1 are treated
            as 1).
    """

    def __init__(self, max_depth=2, min_samples_split=2, min_samples_leaf=1):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.tree = None

    @staticmethod
    def _sse(values):
        """Return the sum of squared deviations of *values* from their mean."""
        return float(np.sum((values - np.mean(values)) ** 2))

    def split(self, X, y):
        """Find the split that minimises the children's total squared error.

        Args:
            X: 2-D array-like of shape (n_samples, n_features).
            y: Array-like of n_samples target values.

        Returns:
            ``(feature_index, threshold)`` of the best split, or
            ``(None, None)`` when no admissible split exists (for example
            when every feature is constant).
        """
        X = np.asarray(X)
        y = np.asarray(y, dtype=float)
        min_leaf = max(1, self.min_samples_leaf)

        best_feature_index, best_threshold, best_loss = None, None, np.inf
        for feature_index in range(X.shape[1]):
            feature_values = X[:, feature_index]
            # Each distinct value is tried once. The largest one is skipped
            # because "<= max" would leave the right child empty.
            for threshold in np.unique(feature_values)[:-1]:
                left_y = y[feature_values <= threshold]
                right_y = y[feature_values > threshold]
                if len(left_y) < min_leaf or len(right_y) < min_leaf:
                    continue
                # CART criterion: total (size-weighted) squared error of the
                # two children, not the unweighted sum of their MSEs.
                loss = self._sse(left_y) + self._sse(right_y)
                if loss < best_loss:
                    best_feature_index = feature_index
                    best_threshold = threshold
                    best_loss = loss
        return best_feature_index, best_threshold

    def build_tree(self, X, y, depth):
        """Recursively build the subtree for the samples ``(X, y)``.

        Args:
            X: 2-D array-like of shape (n_samples, n_features).
            y: Array-like of n_samples target values.
            depth: Depth of the node being built (the root is 0).

        Returns:
            A float leaf value (mean of ``y``) or a tuple
            ``(feature_index, threshold, left_subtree, right_subtree)``.
        """
        X = np.asarray(X)
        y = np.asarray(y, dtype=float)

        depth_exhausted = self.max_depth is not None and depth >= self.max_depth
        if depth_exhausted or len(y) < self.min_samples_split:
            return float(np.mean(y))

        feature_index, threshold = self.split(X, y)
        if feature_index is None:
            return float(np.mean(y))

        left_indices = X[:, feature_index] <= threshold
        right_indices = X[:, feature_index] > threshold
        left_subtree = self.build_tree(X[left_indices], y[left_indices], depth + 1)
        right_subtree = self.build_tree(X[right_indices], y[right_indices], depth + 1)
        return (feature_index, threshold, left_subtree, right_subtree)

    def fit(self, X, y):
        """Fit the tree to the training data and store it in ``self.tree``.

        Args:
            X: 2-D array-like of shape (n_samples, n_features).
            y: Array-like of n_samples target values.

        Raises:
            ValueError: If ``X`` is not 2-D, if ``X`` and ``y`` disagree on
                the number of samples, or if there are no samples.
        """
        X = np.asarray(X)
        y = np.asarray(y, dtype=float)
        if X.ndim != 2:
            raise ValueError("X must be a 2-D array of shape (n_samples, n_features)")
        if y.shape[:1] != X.shape[:1]:
            raise ValueError("X and y must contain the same number of samples")
        if X.shape[0] == 0:
            raise ValueError("Cannot fit a tree on an empty dataset")
        self.tree = self.build_tree(X, y, 0)

    def predict(self, X):
        """Predict a target value for every row of ``X``.

        Args:
            X: 2-D array-like of shape (n_samples, n_features).

        Returns:
            NumPy array with one prediction per row of ``X``.

        Raises:
            RuntimeError: If the tree has not been fitted yet.
        """
        if self.tree is None:
            raise RuntimeError("Tree has not been trained yet!")

        predictions = []
        for sample in np.asarray(X):
            node = self.tree
            # Internal nodes are tuples; anything else is a leaf value.
            while isinstance(node, tuple):
                feature_index, threshold, left_subtree, right_subtree = node
                if sample[feature_index] <= threshold:
                    node = left_subtree
                else:
                    node = right_subtree
            predictions.append(node)
        return np.array(predictions)
```
其中，`max_depth` 表示树的最大深度，`min_samples_split` 表示每个节点最少需要多少个样本才能继续分裂。`split` 函数用于在当前节点上找到最佳的分裂特征和阈值，`build_tree` 函数递归地构建树，`fit` 函数用训练数据构建整棵树，`predict` 函数用于对新数据进行预测。
阅读全文