CART 分类树后剪枝代码
时间: 2023-09-11 15:09:53 浏览: 36
以下是一个简单的 CART 分类树后剪枝的代码实现:
```python
import numpy as np
class DecisionTree():
    """CART classification tree with cost-complexity post-pruning.

    The tree is grown greedily by ``fit`` using axis-aligned splits of the
    form ``x[feature] <= threshold``.  It can afterwards be simplified with
    ``prune``, which collapses subtrees whose error reduction does not
    justify their number of leaves.

    Args:
        max_depth: Maximum depth of the grown tree (the root is depth 0).
        min_samples_leaf: Minimum number of samples every child of a split
            must receive.
        criterion: Split quality measure, ``'gini'`` or ``'entropy'``.
        alpha: Complexity penalty per leaf used by ``prune``.  Larger values
            prune more aggressively.

    The fitted tree is stored in the ``tree`` attribute once ``fit`` ran.
    """

    def __init__(self, max_depth=5, min_samples_leaf=1, criterion='gini', alpha=0.1):
        self.max_depth = max_depth
        self.min_samples_leaf = min_samples_leaf
        self.criterion = criterion
        self.alpha = alpha

    def fit(self, X, y):
        """Grow the tree on samples ``X`` (2-D) with class labels ``y`` (1-D).

        Raises:
            ValueError: If the data set is empty or ``criterion`` is invalid.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        if len(y) == 0:
            raise ValueError('Cannot fit a tree on an empty data set')
        self.tree = self._build_tree(X, y, depth=0)

    def predict(self, X):
        """Return an array with the predicted label for every row of ``X``."""
        return np.array([self._predict(inputs) for inputs in X])

    def _build_tree(self, X, y, depth):
        """Recursively grow the subtree for the samples ``X``/``y``."""
        num_samples = X.shape[0]
        num_labels = len(np.unique(y))

        # Check termination criteria
        if (depth >= self.max_depth
                or num_samples < self.min_samples_leaf
                or num_labels == 1):
            return Node(leaf_value=self._calculate_leaf_value(y))

        # Find best split; there may be none (e.g. all rows are identical
        # or every candidate violates min_samples_leaf).
        best_feature, best_threshold = self._find_best_split(X, y, num_labels)
        if best_feature is None:
            return Node(leaf_value=self._calculate_leaf_value(y))

        # Split data
        column = X[:, best_feature]
        left_idxs = np.flatnonzero(column <= best_threshold)
        right_idxs = np.flatnonzero(column > best_threshold)

        # Recursive call for left and right sub-tree
        left_subtree = self._build_tree(X[left_idxs], y[left_idxs], depth + 1)
        right_subtree = self._build_tree(X[right_idxs], y[right_idxs], depth + 1)

        # Create node with best split
        return Node(best_feature=best_feature, best_threshold=best_threshold,
                    left_subtree=left_subtree, right_subtree=right_subtree)

    def _find_best_split(self, X, y, num_labels):
        """Return ``(feature, threshold)`` of the lowest-impurity valid split.

        A split is valid when both children receive at least
        ``min_samples_leaf`` (and at least one) samples.  Returns
        ``(None, None)`` when no valid split exists.
        """
        best_feature = None
        best_threshold = None
        # Start at +inf: entropy with more than two classes can exceed 1.0,
        # so a finite starting value could reject every candidate.
        best_impurity = np.inf
        min_leaf = max(1, self.min_samples_leaf)

        for feature_idx in range(X.shape[1]):
            feature_values = X[:, feature_idx]
            for threshold in np.unique(feature_values):
                # Split data
                left_idxs = np.flatnonzero(feature_values <= threshold)
                right_idxs = np.flatnonzero(feature_values > threshold)

                # Check if split is valid
                if len(left_idxs) < min_leaf or len(right_idxs) < min_leaf:
                    continue

                # Calculate impurity
                impurity = self._calculate_impurity(y, num_labels, left_idxs, right_idxs)

                # Update best split if necessary
                if impurity < best_impurity:
                    best_feature = feature_idx
                    best_threshold = threshold
                    best_impurity = impurity

        return best_feature, best_threshold

    def _calculate_impurity(self, y, num_labels, left_idxs, right_idxs):
        """Return the sample-weighted impurity of the two children of a split.

        Raises:
            ValueError: If ``self.criterion`` is neither ``'gini'`` nor
                ``'entropy'``.
        """
        if self.criterion == 'gini':
            measure = self._gini_impurity
        elif self.criterion == 'entropy':
            measure = self._entropy
        else:
            raise ValueError('Invalid criterion')

        num_left = len(left_idxs)
        num_right = len(right_idxs)
        total = num_left + num_right

        # Weighted sum of impurities
        return ((num_left / total) * measure(y[left_idxs], num_labels)
                + (num_right / total) * measure(y[right_idxs], num_labels))

    def _calculate_leaf_value(self, y):
        """Return the most common class label in ``y``."""
        labels, counts = np.unique(y, return_counts=True)
        return labels[np.argmax(counts)]

    def _predict(self, inputs, node=None):
        """Return the label of the leaf reached by the sample ``inputs``.

        Traversal starts at ``node``, or at the root of the fitted tree when
        ``node`` is omitted.
        """
        if node is None:
            node = self.tree
        while node.left_subtree is not None:
            if inputs[node.best_feature] <= node.best_threshold:
                node = node.left_subtree
            else:
                node = node.right_subtree
        return node.leaf_value

    def _gini_impurity(self, y, num_labels):
        """Return the Gini impurity ``1 - sum(p_k ** 2)`` of the labels ``y``."""
        _, counts = np.unique(y, return_counts=True)
        probabilities = counts / len(y)
        return 1.0 - np.sum(probabilities ** 2)

    def _entropy(self, y, num_labels):
        """Return the Shannon entropy (base 2) of the labels ``y``."""
        _, counts = np.unique(y, return_counts=True)
        probabilities = counts / len(y)
        return 0.0 - np.sum(probabilities * np.log2(probabilities))

    def _calculate_pruning_alpha(self, y, y_hat, num_leaves, num_total=None):
        """Return the effective alpha ``g(t)`` of collapsing a subtree.

        ``g(t) = (R(t) - R(T_t)) / (|T_t| - 1)`` where ``R(t)`` is the error
        rate if the subtree were replaced by a single majority-class leaf,
        ``R(T_t)`` the error rate of the subtree itself and ``|T_t|`` its
        number of leaves.

        Args:
            y: True labels of the samples reaching the subtree.
            y_hat: Labels the subtree currently predicts for those samples.
            num_leaves: Number of leaves of the subtree.
            num_total: Number of samples the error rates are normalised by;
                defaults to ``len(y)``.

        Returns:
            ``g(t)`` as a float, or ``inf`` when there is nothing to collapse
            (a single leaf) or no sample to judge by.
        """
        y = np.asarray(y)
        y_hat = np.asarray(y_hat)
        if num_leaves <= 1 or len(y) == 0:
            return float('inf')
        if num_total is None:
            num_total = len(y)

        _, counts = np.unique(y, return_counts=True)
        leaf_errors = len(y) - counts.max()
        subtree_errors = np.sum(y != y_hat)
        return float((leaf_errors - subtree_errors) / (num_total * (num_leaves - 1)))

    def _prune_tree(self, node, X, y, num_total=None):
        """Prune the subtree rooted at ``node`` bottom-up and return its root.

        Only the samples of ``X``/``y`` that reach ``node`` are passed in;
        they are routed to the children with the same rule ``_predict`` uses.
        After both children are pruned, the node is collapsed into a
        majority-class leaf when its effective alpha does not exceed
        ``self.alpha``.
        """
        if num_total is None:
            num_total = len(y)

        # Leaves cannot be simplified; nodes no sample reaches are left
        # untouched because there is no evidence to choose a label from.
        if node.left_subtree is None or len(y) == 0:
            return node

        goes_left = X[:, node.best_feature] <= node.best_threshold
        node.left_subtree = self._prune_tree(
            node.left_subtree, X[goes_left], y[goes_left], num_total)
        node.right_subtree = self._prune_tree(
            node.right_subtree, X[~goes_left], y[~goes_left], num_total)

        # Check if node can be pruned
        y_hat = np.array([self._predict(inputs, node) for inputs in X])
        num_leaves = self._count_leaves(node)
        alpha = self._calculate_pruning_alpha(y, y_hat, num_leaves, num_total)
        if alpha <= self.alpha:
            return Node(leaf_value=self._calculate_leaf_value(y))
        return node

    def prune(self, X, y):
        """Post-prune the fitted tree using the labelled samples ``X``/``y``."""
        self.tree = self._prune_tree(self.tree, np.asarray(X), np.asarray(y))

    def _count_leaves(self, node):
        """Return the number of leaves in the subtree rooted at ``node``."""
        if node.left_subtree is None:
            return 1
        return self._count_leaves(node.left_subtree) + self._count_leaves(node.right_subtree)
class Node():
    """A single node of the decision tree.

    Internal nodes carry a split (``best_feature``, ``best_threshold``) and
    two children; samples with ``x[best_feature] <= best_threshold`` go to
    ``left_subtree``, all others to ``right_subtree``.  Leaves carry only
    ``leaf_value``, the predicted class label.

    Args:
        best_feature: Column index of the feature the node splits on.
        best_threshold: Threshold the feature value is compared against.
        left_subtree: Child node for the ``<=`` branch.
        right_subtree: Child node for the ``>`` branch.
        leaf_value: Class label predicted by a leaf.
    """

    def __init__(self, best_feature=None, best_threshold=None, left_subtree=None, right_subtree=None, leaf_value=None):
        self.best_feature = best_feature
        self.best_threshold = best_threshold
        self.left_subtree = left_subtree
        self.right_subtree = right_subtree
        self.leaf_value = leaf_value

    @property
    def is_leaf(self):
        """Whether the node has no children."""
        return self.left_subtree is None and self.right_subtree is None

    def __repr__(self):
        if self.is_leaf:
            return f"Node(leaf_value={self.leaf_value!r})"
        return (f"Node(best_feature={self.best_feature!r}, "
                f"best_threshold={self.best_threshold!r})")
```
在上述代码中,我们添加了一个 `_calculate_pruning_alpha` 函数来计算修剪 α 值,以及一个 `_prune_tree` 函数来递归地修剪树。在 `prune` 函数中,我们将调用 `_prune_tree` 来修剪整个树。