CART decision tree Python code
Below is a simple Python implementation of a CART decision tree:
```python
import numpy as np


class CARTDecisionTree:
    def __init__(self, max_depth=10, min_samples_split=2):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split

    def fit(self, X, y):
        self.tree = self.build_tree(X, y)

    def build_tree(self, X, y, depth=0):
        n_samples, n_features = X.shape
        n_labels = len(np.unique(y))
        # Stop splitting and return a leaf (majority class) when the depth limit,
        # node purity, or minimum-sample criterion is reached
        if (depth >= self.max_depth or
                n_labels == 1 or
                n_samples < self.min_samples_split):
            return int(np.argmax(np.bincount(y)))
        # Find the best split feature and threshold
        best_feature, best_threshold = self.get_best_split(X, y, n_samples, n_features)
        # If no valid split exists, fall back to a leaf
        if best_feature is None:
            return int(np.argmax(np.bincount(y)))
        # Partition the samples and recursively build the subtrees
        left_indices = X[:, best_feature] < best_threshold
        right_indices = X[:, best_feature] >= best_threshold
        left_subtree = self.build_tree(X[left_indices], y[left_indices], depth + 1)
        right_subtree = self.build_tree(X[right_indices], y[right_indices], depth + 1)
        return {'feature': best_feature, 'threshold': best_threshold,
                'left': left_subtree, 'right': right_subtree}

    def get_best_split(self, X, y, n_samples, n_features):
        best_gini = float('inf')
        best_feature, best_threshold = None, None
        # Scan every feature and every candidate threshold for the lowest weighted Gini
        for feature in range(n_features):
            thresholds = np.unique(X[:, feature])
            for threshold in thresholds:
                left_indices = X[:, feature] < threshold
                right_indices = X[:, feature] >= threshold
                # Skip splits that leave one side empty (the masks are boolean arrays)
                if left_indices.sum() == 0 or right_indices.sum() == 0:
                    continue
                gini = self.gini_index(y, left_indices, right_indices)
                if gini < best_gini:
                    best_gini = gini
                    best_feature = feature
                    best_threshold = threshold
        return best_feature, best_threshold

    def gini_index(self, y, left_indices, right_indices):
        # left_indices / right_indices are boolean masks, so count the True entries
        n_left, n_right = int(left_indices.sum()), int(right_indices.sum())
        gini_left, gini_right = 0.0, 0.0
        if n_left > 0:
            _, counts_left = np.unique(y[left_indices], return_counts=True)
            gini_left = 1 - np.sum((counts_left / n_left) ** 2)
        if n_right > 0:
            _, counts_right = np.unique(y[right_indices], return_counts=True)
            gini_right = 1 - np.sum((counts_right / n_right) ** 2)
        # Weighted average of the two children's Gini impurities
        return (n_left * gini_left + n_right * gini_right) / (n_left + n_right)

    def predict(self, X):
        return np.array([self.predict_sample(x, self.tree) for x in X])

    def predict_sample(self, x, tree):
        # A leaf is stored as a plain class label; an internal node is a dict
        if not isinstance(tree, dict):
            return tree
        feature, threshold = tree['feature'], tree['threshold']
        if x[feature] < threshold:
            return self.predict_sample(x, tree['left'])
        else:
            return self.predict_sample(x, tree['right'])
```
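As a quick sanity check, here is a minimal usage sketch; the toy dataset and the expected output below are illustrative and not part of the original answer:

```python
import numpy as np

# Tiny illustrative dataset: two numeric features, binary labels
X = np.array([[2.7, 2.5],
              [1.3, 3.3],
              [3.6, 4.4],
              [7.6, 2.7],
              [8.6, -0.2],
              [7.9, 3.5]])
y = np.array([0, 0, 0, 1, 1, 1])

tree = CARTDecisionTree(max_depth=3)
tree.fit(X, y)
print(tree.predict(X))  # on this tiny set the tree should reproduce the training labels
```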
Note that the CART decision tree implemented above only supports classification. To use it for regression, the `gini_index` method needs to be replaced with a different split criterion (such as MSE), and the leaves should store the mean target value instead of the majority class.
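As a rough sketch of that regression change (not part of the original answer; the function names `mse_index` and `leaf_value` are only illustrative), the criterion would weight the children's mean squared error instead of their Gini impurity:

```python
import numpy as np

def mse_index(y, left_indices, right_indices):
    # Weighted mean squared error of the two children, used in place of the Gini index
    n_left, n_right = int(left_indices.sum()), int(right_indices.sum())
    mse_left = np.mean((y[left_indices] - y[left_indices].mean()) ** 2) if n_left > 0 else 0.0
    mse_right = np.mean((y[right_indices] - y[right_indices].mean()) ** 2) if n_right > 0 else 0.0
    return (n_left * mse_left + n_right * mse_right) / (n_left + n_right)

def leaf_value(y):
    # Regression leaves predict the mean target rather than the majority class
    return float(np.mean(y))
```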