实现CART算法的完整Python代码
时间: 2023-08-06 15:03:24 浏览: 101
以下是实现CART算法的完整Python代码:
```python
import numpy as np
from collections import Counter
class CARTDecisionTree:
    """CART classification tree that chooses splits by weighted Gini impurity.

    Each internal node tests ``x[feature] <= threshold``; samples for which
    the test holds go to the left child, all others go to the right child.
    Leaves predict the most common training label that reached them.

    Args:
        max_depth: Depth at which growth stops and a leaf is created.
        min_samples_split: A node with fewer samples than this becomes a leaf.
        min_samples_leaf: Minimum number of samples each child of a split must
            receive. Values below 1 are treated as 1 so that a split can never
            produce an empty child.
    """

    def __init__(self, max_depth=5, min_samples_split=2, min_samples_leaf=1):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.tree = None

    def fit(self, X, y):
        """Grow the tree from feature matrix ``X`` and label vector ``y``.

        Both arguments are converted with ``np.asarray`` so that nested lists
        are accepted as well as arrays (column slicing needs an ndarray).
        """
        X = np.asarray(X)
        y = np.asarray(y)
        self.tree = self.build_tree(X, y)

    def predict(self, X):
        """Return an array with one predicted label per row of ``X``."""
        return np.array([self.predict_one(x, self.tree) for x in X])

    def predict_one(self, x, node):
        """Walk from ``node`` down to a leaf for sample ``x``; return its label."""
        while not node.is_leaf:
            if x[node.feature] <= node.threshold:
                node = node.left
            else:
                node = node.right
        return node.label

    def build_tree(self, X, y, depth=0):
        """Recursively build the subtree for the samples ``(X, y)``.

        A leaf is produced when the depth limit is reached, when there are too
        few samples to split, when all labels agree, or when no admissible
        split exists.
        """
        if (
            depth == self.max_depth
            or y.size < self.min_samples_split
            or len(set(y)) == 1
        ):
            return Node(self.get_label(y), is_leaf=True)

        best_feature, best_threshold = self.get_best_split(X, y)
        if best_feature is None or best_threshold is None:
            return Node(self.get_label(y), is_leaf=True)

        left_mask = X[:, best_feature] <= best_threshold
        # Use the complement (not a separate ``>`` test) so every sample lands
        # in exactly one child, mirroring the else-branch in predict_one.
        right_mask = ~left_mask
        left = self.build_tree(X[left_mask], y[left_mask], depth + 1)
        right = self.build_tree(X[right_mask], y[right_mask], depth + 1)
        return Node(
            None,
            feature=best_feature,
            threshold=best_threshold,
            left=left,
            right=right,
        )

    def get_best_split(self, X, y):
        """Find the ``(feature, threshold)`` pair with the lowest weighted Gini.

        Candidate thresholds are the unique values of each feature column.
        Candidates that would leave either child with fewer than
        ``min_samples_leaf`` samples (and at least one sample) are skipped;
        this also rules out the largest value of a column, which would send
        every sample left and leave the right child empty.

        Returns:
            ``(feature_index, threshold)``, or ``(None, None)`` when no
            admissible split exists. Ties keep the first candidate found.
        """
        best_feature, best_threshold = None, None
        best_gini = float("inf")
        min_leaf = max(1, self.min_samples_leaf)
        n_samples = y.size

        for feature in range(X.shape[1]):
            column = X[:, feature]
            for threshold in np.unique(column):
                left_size = int(np.count_nonzero(column <= threshold))
                right_size = n_samples - left_size
                if left_size < min_leaf or right_size < min_leaf:
                    continue
                gini = self.gini_index(X, y, feature, threshold)
                if gini < best_gini:
                    best_feature, best_threshold, best_gini = feature, threshold, gini

        return best_feature, best_threshold

    def gini_index(self, X, y, feature, threshold):
        """Return the size-weighted Gini impurity of splitting at ``threshold``.

        The left side holds samples with ``X[:, feature] <= threshold``; the
        right side holds the remaining samples.
        """
        left_mask = X[:, feature] <= threshold
        y_left = y[left_mask]
        y_right = y[~left_mask]
        n = y.size
        return (y_left.size / n) * self.get_gini(y_left) + (
            y_right.size / n
        ) * self.get_gini(y_right)

    def get_gini(self, y):
        """Return the Gini impurity ``1 - sum(p_k ** 2)`` of the labels ``y``.

        For an empty ``y`` there are no classes to sum over, so the result is 1.
        """
        total = len(y)
        return 1 - sum([(count / total) ** 2 for count in Counter(y).values()])

    def get_label(self, y):
        """Return the most frequent label in ``y`` (``y`` must be non-empty)."""
        return Counter(y).most_common(1)[0][0]
class Node:
    """A single node of the decision tree.

    A leaf node carries a predicted ``label`` and has ``is_leaf`` set to True.
    An internal node carries the split (``feature`` index and ``threshold``)
    together with its ``left`` and ``right`` child nodes.

    Args:
        label: Predicted label stored on the node (``None`` for internal nodes
            as created by the tree builder).
        feature: Index of the feature tested at this node.
        threshold: Value the feature is compared against with ``<=``.
        left: Child followed when the comparison holds.
        right: Child followed otherwise.
        is_leaf: Whether this node is a leaf.
    """

    def __init__(self, label, feature=None, threshold=None, left=None, right=None, is_leaf=False):
        self.label = label
        self.feature = feature
        self.threshold = threshold
        self.left = left
        self.right = right
        self.is_leaf = is_leaf

    def __repr__(self):
        # Compact, non-recursive representation to keep debugging output short.
        if self.is_leaf:
            return "Node(label={!r}, is_leaf=True)".format(self.label)
        return "Node(feature={!r}, threshold={!r})".format(self.feature, self.threshold)
```
这段代码实现了CART分类决策树的训练和预测过程:调用`fit`方法训练模型,再调用`predict`方法进行预测。初始化时,可以指定树的最大深度(`max_depth`)、节点分裂所需的最小样本数(`min_samples_split`)以及叶子节点的最小样本数(`min_samples_leaf`)等参数。