Writing a CART algorithm in Python
A Python implementation of the decision tree algorithm (CART)
CART (Classification and Regression Trees) is a common decision tree algorithm that can be applied to both classification and regression problems. The example code below implements the classification case in Python:
```python
import numpy as np
class CART:
    def __init__(self, max_depth=5, min_samples_split=2):
        # Hyperparameters: maximum tree depth and the minimum number of
        # samples a branch must contain for a split to be allowed
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.tree = None

    def split(self, X, y, feature_idx, threshold):
        # Partition the data by comparing one feature against a threshold
        left_mask = X[:, feature_idx] <= threshold
        right_mask = X[:, feature_idx] > threshold
        left_X, left_y = X[left_mask], y[left_mask]
        right_X, right_y = X[right_mask], y[right_mask]
        return left_X, left_y, right_X, right_y

    def gini_score(self, y):
        # Gini impurity: 1 minus the sum of squared class proportions
        _, counts = np.unique(y, return_counts=True)
        proportions = counts / len(y)
        return 1 - np.sum(proportions ** 2)

    def get_best_split(self, X, y):
        # Exhaustive search over every feature and every observed value
        # for the split with the lowest weighted Gini impurity
        best_feature_idx, best_threshold, best_gini = None, None, np.inf
        for feature_idx in range(X.shape[1]):
            for threshold in np.unique(X[:, feature_idx]):
                left_X, left_y, right_X, right_y = self.split(X, y, feature_idx, threshold)
                if len(left_y) < self.min_samples_split or len(right_y) < self.min_samples_split:
                    continue
                gini_left, gini_right = self.gini_score(left_y), self.gini_score(right_y)
                weighted_gini = (len(left_y) / len(y)) * gini_left + (len(right_y) / len(y)) * gini_right
                if weighted_gini < best_gini:
                    best_feature_idx, best_threshold, best_gini = feature_idx, threshold, weighted_gini
        return best_feature_idx, best_threshold, best_gini

    def fit(self, X, y, depth=1):
        # Return a leaf (majority class) when the depth or sample limits are hit;
        # labels must be non-negative integers for np.bincount to work
        if depth > self.max_depth or len(y) < self.min_samples_split:
            return np.bincount(y).argmax()
        best_feature_idx, best_threshold, _ = self.get_best_split(X, y)
        if best_feature_idx is None:  # no valid split found, make a leaf
            return np.bincount(y).argmax()
        left_X, left_y, right_X, right_y = self.split(X, y, best_feature_idx, best_threshold)
        node = {'feature_idx': best_feature_idx, 'threshold': best_threshold}
        node['left'] = self.fit(left_X, left_y, depth + 1)
        node['right'] = self.fit(right_X, right_y, depth + 1)
        self.tree = node  # the outermost call assigns last, so self.tree ends up as the root
        return node

    def predict(self, x):
        # Classify a single sample x by walking from the root down to a leaf
        node = self.tree
        while isinstance(node, dict):
            if x[node['feature_idx']] <= node['threshold']:
                node = node['left']
            else:
                node = node['right']
        return node
```
The example code defines a CART class with the following methods:
- `__init__(self, max_depth=5, min_samples_split=2)`: initializer; takes the maximum tree depth and the minimum number of samples required to attempt a split.
- `split(self, X, y, feature_idx, threshold)`: splits the dataset into left and right parts based on a feature and a threshold.
- `gini_score(self, y)`: computes the Gini impurity of a set of labels (a small worked example follows this list).
- `get_best_split(self, X, y)`: searches for the best feature and threshold to split the dataset.
- `fit(self, X, y, depth=1)`: builds the decision tree recursively.
- `predict(self, x)`: predicts the class of a single new sample.
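As a quick sanity check on `gini_score`, recall that the Gini impurity of a perfectly mixed two-class set is 0.5 and that of a pure set is 0. A minimal sketch (the label arrays below are made up for illustration):
```python
import numpy as np

cart = CART()
print(cart.gini_score(np.array([0, 0, 1, 1])))  # 0.5: two classes, 50/50 split
print(cart.gini_score(np.array([1, 1, 1, 1])))  # 0.0: a pure node
print(cart.gini_score(np.array([0, 0, 0, 1])))  # 0.375 = 1 - (0.75**2 + 0.25**2)
```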
The `fit` method builds the decision tree recursively: it first calls `get_best_split` to find the best feature and threshold, then recursively builds the left and right subtrees, and finally returns the root node. The `predict` method uses the built tree to classify a new sample, walking down the tree node by node until it reaches a leaf.
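A minimal end-to-end usage sketch (the toy data below is invented for illustration; class labels must be non-negative integers because `fit` uses `np.bincount`). Since `predict` handles one sample at a time, a batch is predicted by looping over rows:
```python
import numpy as np

# Toy 2-D dataset: class 0 clusters around small values, class 1 around large values
X_train = np.array([[1.0, 2.0], [1.5, 1.8], [2.0, 2.2],
                    [7.0, 8.0], [7.5, 8.5], [8.0, 7.8]])
y_train = np.array([0, 0, 0, 1, 1, 1])

cart = CART(max_depth=3)
cart.fit(X_train, y_train)

# Predict each test sample individually
X_test = np.array([[1.2, 2.1], [7.8, 8.2]])
y_pred = [cart.predict(x) for x in X_test]
print(y_pred)  # expected: [0, 1]
```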