C4.5 决策树 Python 代码
时间: 2024-04-03 15:29:55 浏览: 217
C4.5 决策树是一种经典的机器学习算法，主要用于分类任务（C4.5 本身不处理回归问题）。下面是一个简化版 C4.5 决策树的 Python 代码示例：
```python
import numpy as np
from collections import Counter
class Node:
    """One node of a binary decision tree.

    An internal node sends a sample to ``left`` when
    ``x[feature] <= threshold`` and to ``right`` otherwise. A leaf has no
    split (``feature is None``) and stores the class it predicts in
    ``label``.
    """

    def __init__(self, feature=None, threshold=None, label=None, left=None, right=None):
        self.feature = feature  # index of the split feature; None for a leaf
        self.threshold = threshold  # split threshold on that feature
        self.label = label  # class predicted by a leaf
        self.left = left  # subtree for x[feature] <= threshold
        self.right = right  # subtree for x[feature] > threshold

    @property
    def is_leaf(self):
        """Return True when this node has no split and predicts ``label``."""
        return self.feature is None


class C45DecisionTree:
    """Binary-split decision-tree classifier in the style of C4.5.

    Every internal node tests ``x[feature] <= threshold``, with thresholds
    drawn from values observed in the training data. Candidate splits are
    ranked as in C4.5: those whose information gain is below the average
    gain are discarded, and the remaining split with the highest gain ratio
    (information gain divided by split information) is chosen.

    Simplifications relative to full C4.5: no pruning, no special handling
    of missing values, and no multiway splits on categorical features.
    """

    def __init__(self, min_samples_split=2, max_depth=float('inf'), min_info_gain=1e-7):
        self.min_samples_split = min_samples_split  # nodes with fewer samples become leaves
        self.max_depth = max_depth  # nodes at this depth become leaves (root is depth 0)
        self.min_info_gain = min_info_gain  # splits with smaller information gain are rejected
        self.tree = None  # root Node; set by fit()

    def _calculate_entropy(self, y):
        """Return the Shannon entropy of the labels ``y`` in bits (0.0 if ``y`` is empty)."""
        n = len(y)
        if n == 0:
            return 0.0
        counts = np.fromiter(Counter(y).values(), dtype=float)
        probs = counts / n
        return float(-np.sum(probs * np.log2(probs)))

    def _calculate_info_gain(self, X, y, feature, threshold):
        """Return the information gain of splitting ``(X, y)`` on ``X[:, feature] <= threshold``."""
        n = len(y)
        if n == 0:
            return 0.0
        left_mask = X[:, feature] <= threshold
        right_mask = ~left_mask
        n_left = np.count_nonzero(left_mask)
        children_entropy = (
            n_left * self._calculate_entropy(y[left_mask])
            + (n - n_left) * self._calculate_entropy(y[right_mask])
        ) / n
        return self._calculate_entropy(y) - children_entropy

    def _calculate_gain_ratio(self, X, y, feature, threshold):
        """Return C4.5's gain ratio for the split ``X[:, feature] <= threshold``.

        The split information is the entropy of the left/right partition
        itself, so it is computed by treating the boolean mask as labels.
        Returns 0.0 when one side of the split is empty.
        """
        left_mask = X[:, feature] <= threshold
        split_info = self._calculate_entropy(left_mask)
        if split_info == 0.0:
            return 0.0
        return self._calculate_info_gain(X, y, feature, threshold) / split_info

    @staticmethod
    def _majority_label(y):
        """Return the most frequent label in ``y``; ties go to the label seen first."""
        return Counter(y).most_common(1)[0][0]

    def _split(self, X, y):
        """Return the best ``(feature, threshold)`` for ``(X, y)``, or ``(None, None)``.

        Candidate thresholds for a feature are its distinct observed values
        except the largest, which would leave the right side empty. A
        candidate qualifies only if its information gain is positive and at
        least ``min_info_gain``. Among qualifying candidates, those with gain
        below the average gain are dropped, and the one with the highest gain
        ratio is returned.
        """
        candidates = []  # (info_gain, feature, threshold)
        for feature in range(X.shape[1]):
            for threshold in np.unique(X[:, feature])[:-1]:
                gain = self._calculate_info_gain(X, y, feature, threshold)
                if gain > 0 and gain >= self.min_info_gain:
                    candidates.append((gain, feature, threshold))
        if not candidates:
            return None, None

        gains = [gain for gain, _, _ in candidates]
        # min() guards against float round-off pushing the mean above every
        # gain when all gains are equal, which would discard every candidate.
        cutoff = min(sum(gains) / len(gains), max(gains))

        best_ratio = -np.inf
        best_feature, best_threshold = None, None
        for gain, feature, threshold in candidates:
            if gain < cutoff:
                continue
            ratio = self._calculate_gain_ratio(X, y, feature, threshold)
            if ratio > best_ratio:
                best_ratio, best_feature, best_threshold = ratio, feature, threshold
        return best_feature, best_threshold

    def _build_tree(self, X, y, depth):
        """Recursively grow and return the subtree for samples ``(X, y)`` at ``depth``."""
        # Stop on a pure node, at the depth limit, or when too few samples remain.
        if len(set(y)) == 1 or depth >= self.max_depth or len(y) < self.min_samples_split:
            return Node(label=self._majority_label(y))

        feature, threshold = self._split(X, y)
        if feature is None:
            return Node(label=self._majority_label(y))

        left_mask = X[:, feature] <= threshold
        # Negating the mask (instead of a separate ">" test) keeps every sample,
        # including NaNs, and routes it exactly as _predict_single does.
        right_mask = ~left_mask
        return Node(
            feature=feature,
            threshold=threshold,
            left=self._build_tree(X[left_mask], y[left_mask], depth + 1),
            right=self._build_tree(X[right_mask], y[right_mask], depth + 1),
        )

    def fit(self, X, y):
        """Build the tree from training data and return ``self``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Feature values; each feature must support ``<=`` comparisons.
        y : array-like of shape (n_samples,)
            Class labels.

        Raises
        ------
        ValueError
            If ``X`` is not two-dimensional, is empty, or its length differs
            from that of ``y``.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        if X.ndim != 2:
            raise ValueError(f"X must be 2-D, got {X.ndim} dimension(s)")
        if len(X) == 0:
            raise ValueError("cannot fit on an empty dataset")
        if len(X) != len(y):
            raise ValueError(f"X has {len(X)} samples but y has {len(y)}")
        self.tree = self._build_tree(X, y, 0)
        return self

    def _predict_single(self, x, node):
        """Return the label predicted for one sample ``x``, starting the descent at ``node``."""
        # Iterative descent avoids hitting Python's recursion limit on deep trees.
        while not node.is_leaf:
            node = node.left if x[node.feature] <= node.threshold else node.right
        return node.label

    def predict(self, X):
        """Return a list with the predicted class for each row of ``X``.

        Raises
        ------
        AttributeError
            If the model has not been fitted yet.
        """
        if getattr(self, "tree", None) is None:
            raise AttributeError("C45DecisionTree is not fitted yet; call fit() first")
        return [self._predict_single(x, self.tree) for x in np.asarray(X)]
```
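这段代码实现了一个简化的 C4.5 决策树分类器，包括计算熵、计算信息增益、选择最佳分裂特征和阈值（每次按 `x[feature] <= threshold` 进行二分）以及递归构建决策树等功能，但不包含剪枝。使用 `fit(X, y)` 训练模型（X 为形如（样本数, 特征数）的二维数组，y 为类别标签数组，建议传入 NumPy 数组），再使用 `predict(X)` 进行预测。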
阅读全文