分类回归树的代码
时间: 2023-08-06 08:17:30 浏览: 31
以下是一个简单的分类回归树的Python代码示例:
```python
import numpy as np
class DecisionTree:
    """A minimal CART-style decision tree.

    Splits are chosen by maximizing information gain (Shannon entropy),
    which requires non-negative integer class labels (``np.bincount``);
    leaf predictions are the mean of the labels at the leaf, so for 0/1
    labels the prediction is the positive-class probability.

    Fixes over the original sketch:
    - the recursive build overwrote ``self.tree`` at every level and
      stored leaf means where subtrees belonged, so the fitted tree was
      never usable; building is now done by ``_grow``, which returns
      proper nested node dicts;
    - ``max_depth=None`` (the default) raised ``TypeError`` in the
      ``depth <= self.max_depth`` comparison; it now means "unlimited";
    - leaves are stored as ``{'value': ...}`` dicts so that
      ``_traverse_tree`` (which reads ``tree['value']``) works.
    """

    def __init__(self, max_depth=None, min_samples_split=2):
        # max_depth=None means the tree may grow without a depth limit.
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.tree = {}

    def fit(self, X, y, depth=0):
        """Fit the tree on X of shape (n_samples, n_features) and integer labels y.

        Returns self so calls can be chained (sklearn convention).
        """
        X = np.asarray(X)
        y = np.asarray(y)
        self.tree = self._grow(X, y, depth)
        return self

    def _grow(self, X, y, depth):
        """Recursively build and return a node dict for the samples (X, y)."""
        n_samples, n_features = X.shape
        depth_ok = self.max_depth is None or depth < self.max_depth
        if n_samples >= self.min_samples_split and depth_ok:
            best_feature, best_threshold = self._find_best_split(
                X, y, n_samples, n_features)
            if best_feature is not None and best_threshold is not None:
                left = np.where(X[:, best_feature] <= best_threshold)[0]
                right = np.where(X[:, best_feature] > best_threshold)[0]
                return {
                    'feature': best_feature,
                    'threshold': best_threshold,
                    'left': self._grow(X[left], y[left], depth + 1),
                    'right': self._grow(X[right], y[right], depth + 1),
                }
        # No valid split (too few samples, depth limit hit, or no
        # threshold separates the data) -> make a leaf.
        return {'value': self._calculate_leaf_value(y)}

    def predict(self, X):
        """Return an array of leaf values, one per row of X."""
        return np.array(
            [self._traverse_tree(x, self.tree) for x in np.asarray(X)])

    def _find_best_split(self, X, y, n_samples, n_features):
        """Exhaustively search every (feature, threshold) pair.

        Returns the pair maximizing information gain, or (None, None)
        when no split leaves both sides non-empty.
        """
        best_gain = -np.inf
        best_feature = None
        best_threshold = None
        for feature in range(n_features):
            feature_values = X[:, feature]
            for threshold in np.unique(feature_values):
                left_indices = np.where(feature_values <= threshold)[0]
                right_indices = np.where(feature_values > threshold)[0]
                # Skip degenerate splits that leave one side empty
                # (always true for the largest unique value).
                if len(left_indices) > 0 and len(right_indices) > 0:
                    gain = self._information_gain(
                        y, y[left_indices], y[right_indices])
                    if gain > best_gain:
                        best_gain = gain
                        best_feature = feature
                        best_threshold = threshold
        return best_feature, best_threshold

    def _calculate_leaf_value(self, y):
        """Leaf prediction: mean of the labels reaching this leaf."""
        return np.mean(y)

    def _information_gain(self, y, y_left, y_right):
        """Entropy reduction achieved by splitting y into y_left / y_right."""
        n = len(y)
        n_l, n_r = len(y_left), len(y_right)
        child_entropy = ((n_l / n) * self._entropy(y_left)
                         + (n_r / n) * self._entropy(y_right))
        return self._entropy(y) - child_entropy

    def _entropy(self, y):
        """Shannon entropy (bits) of integer labels y; 0 for a pure node."""
        probabilities = np.bincount(y) / len(y)
        # Skip p == 0 terms: lim p->0 of p*log2(p) is 0.
        return -np.sum([p * np.log2(p) for p in probabilities if p > 0])

    def _traverse_tree(self, x, tree):
        """Follow split nodes down to a leaf and return its stored value."""
        if 'feature' not in tree:
            return tree['value']
        if x[tree['feature']] <= tree['threshold']:
            return self._traverse_tree(x, tree['left'])
        return self._traverse_tree(x, tree['right'])
```
这里实现的是一个简单的决策树:每个节点通过遍历所有特征及其候选阈值、计算信息增益(基于熵,要求标签为非负整数类别)来选取最佳分割点;叶子节点返回该节点上标签的均值,因此对 0/1 标签可解释为正类概率。另外,还可以通过最大深度(max_depth)和最小分裂样本数(min_samples_split)来控制树的大小、缓解过拟合问题。