python分类回归决策树 CART完整代码
时间: 2023-07-25 19:08:34 浏览: 205
以下是Python实现CART分类回归决策树的完整代码:
```python
import numpy as np
class CARTDecisionTree:
    """CART-style binary classification tree trained with the Gini criterion.

    Splits are axis-aligned thresholds (`x[feature] <= threshold`); leaves
    store the majority class label of the samples that reached them.
    """

    def __init__(self):
        # Root node: either a dict {'feature', 'threshold', 'left', 'right'}
        # or a bare class label for a leaf.
        self.tree = {}

    def fit(self, X, y, feature_names=None, max_depth=5, min_samples_split=2):
        """Grow the tree on training matrix X (n_samples, n_features) and labels y.

        feature_names is kept only for introspection; max_depth and
        min_samples_split bound the recursion.
        """
        self.feature_names = feature_names
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.tree = self.build_tree(X, y)

    def predict(self, X):
        """Return a list with the predicted class label for each row of X."""
        return [self.predict_one(x, self.tree) for x in X]

    def predict_one(self, x, tree):
        """Walk the tree for a single sample x and return its leaf label."""
        # Leaves are stored as plain labels, not dicts.
        if not isinstance(tree, dict):
            return tree
        # BUG FIX: the original unpacked `tree.values()` into three names,
        # but split nodes carry four keys, which raised ValueError at
        # predict time. Access the keys explicitly instead.
        if x[tree['feature']] <= tree['threshold']:
            return self.predict_one(x, tree['left'])
        return self.predict_one(x, tree['right'])

    def build_tree(self, X, y, depth=0):
        """Recursively build and return a (sub)tree for samples (X, y)."""
        num_samples, num_features = X.shape
        num_labels = len(np.unique(y))
        # Stopping rules: depth limit, pure node, or too few samples.
        if depth == self.max_depth or num_labels == 1 or num_samples < self.min_samples_split:
            return self.get_leaf_node(y)
        best_feature, best_threshold = self.get_best_split(X, y, num_samples, num_features)
        # BUG FIX: when no valid split exists (e.g. every feature is
        # constant), get_best_split returns (None, None); the original
        # then crashed indexing X with None. Fall back to a leaf.
        if best_feature is None:
            return self.get_leaf_node(y)
        left_indices = X[:, best_feature] <= best_threshold
        right_indices = ~left_indices
        left_tree = self.build_tree(X[left_indices], y[left_indices], depth + 1)
        right_tree = self.build_tree(X[right_indices], y[right_indices], depth + 1)
        return {'feature': best_feature, 'threshold': best_threshold,
                'left': left_tree, 'right': right_tree}

    def get_best_split(self, X, y, num_samples, num_features):
        """Scan all features/thresholds and return the split with minimal
        weighted Gini impurity, or (None, None) when no split is possible."""
        best_feature = None
        best_threshold = None
        best_gini = 1
        for feature in range(num_features):
            # Sort samples by this feature so candidate thresholds can be
            # swept with incremental left/right label counts (O(n) per feature
            # after the sort, instead of recounting at every threshold).
            thresholds, classes = zip(*sorted(zip(X[:, feature], y)))
            num_left_samples = 0
            num_left_labels = {}
            num_right_samples = num_samples
            # BUG FIX: the right partition must start holding ALL label
            # counts and shed them as samples move left. The original
            # started empty and *incremented*, so every right-side Gini
            # was computed from corrupted counts.
            num_right_labels = {}
            for label in classes:
                num_right_labels[label] = num_right_labels.get(label, 0) + 1
            for i in range(1, num_samples):
                label = classes[i - 1]
                num_left_samples += 1
                num_left_labels[label] = num_left_labels.get(label, 0) + 1
                num_right_samples -= 1
                num_right_labels[label] -= 1
                # Only split between distinct feature values.
                if thresholds[i] == thresholds[i - 1]:
                    continue
                left_gini = self.get_gini(num_left_labels, num_left_samples)
                right_gini = self.get_gini(num_right_labels, num_right_samples)
                gini = (num_left_samples * left_gini + num_right_samples * right_gini) / num_samples
                if gini < best_gini:
                    best_gini = gini
                    best_feature = feature
                    # Midpoint between the two adjacent feature values.
                    best_threshold = (thresholds[i] + thresholds[i - 1]) / 2
        return best_feature, best_threshold

    def get_gini(self, label_counts, num_samples):
        """Gini impurity 1 - sum(p_k^2) for a partition's label counts."""
        impurity = 1
        for label in label_counts:
            prob = label_counts[label] / num_samples
            impurity -= prob ** 2
        return impurity

    def get_leaf_node(self, y):
        """Return the majority class label among y (ties broken arbitrarily)."""
        label_counts = {}
        for label in y:
            label_counts[label] = label_counts.get(label, 0) + 1
        return max(label_counts, key=label_counts.get)
```
其中,`fit` 方法用于拟合训练数据,`predict` 方法用于预测测试数据。我们还实现了 `build_tree` 方法用于构建决策树,`predict_one` 方法用于对单个样本进行预测,`get_best_split` 方法用于找到最佳的分裂点,`get_gini` 方法用于计算基尼不纯度,`get_leaf_node` 方法用于生成叶子节点。
在使用时,我们可以先创建一个决策树对象,然后调用其 `fit` 方法进行训练,最后调用 `predict` 方法进行预测,如下所示:
```python
# Demonstration: fit the tree on four binary samples (XOR labels),
# then predict on the same inputs and print the result.
training_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
training_labels = np.array([0, 1, 1, 0])
test_inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])

model = CARTDecisionTree()
model.fit(training_inputs, training_labels)
predictions = model.predict(test_inputs)
print(predictions)
```
输出结果为:`[0, 1, 1, 0]`,与预期结果相符。
阅读全文