C4.5 决策树模型代码
时间: 2024-04-15 13:22:27 浏览: 158
C4.5决策树代码
3星 · 编辑精心推荐
C4.5 决策树是一种经典的机器学习算法，以信息增益率（gain ratio）作为分裂准则，主要用于分类任务。下面是一个简单的 C4.5 决策树模型的代码示例：
```python
import numpy as np
class Node:
    """A node of a binary decision tree.

    Internal nodes route a sample left when ``x[feature] <= threshold`` and
    right otherwise; leaf nodes (no children) carry the predicted ``label``.
    """

    def __init__(self, feature=None, threshold=None, label=None, left=None, right=None):
        self.feature = feature      # index of the feature used for the split
        self.threshold = threshold  # split threshold on that feature
        self.label = label          # predicted class (leaf nodes only)
        self.left = left            # subtree for samples with x[feature] <= threshold
        self.right = right          # subtree for samples with x[feature] > threshold

    @property
    def is_leaf(self):
        """True when the node has no children."""
        return self.left is None and self.right is None


class C45DecisionTree:
    """C4.5-style decision tree classifier with binary threshold splits.

    Parameters
    ----------
    min_samples_split : int
        Nodes holding fewer samples than this become leaves.
    max_depth : int or float
        Maximum tree depth; the root is at depth 0. Defaults to unlimited.
    criterion : {"gain_ratio", "gain"}
        Split-selection rule. ``"gain_ratio"`` (default) follows C4.5: among
        candidate splits whose information gain is at least the mean gain,
        choose the one with the highest gain ratio. ``"gain"`` chooses the
        split with the highest plain information gain (ID3-style).
    """

    _CRITERIA = ("gain_ratio", "gain")
    # Gains at or below this are treated as zero, so floating-point round-off
    # cannot produce a useless split.
    _MIN_GAIN = 1e-12

    def __init__(self, min_samples_split=2, max_depth=float('inf'),
                 criterion="gain_ratio"):
        if criterion not in self._CRITERIA:
            raise ValueError(
                f"criterion must be one of {self._CRITERIA}, got {criterion!r}")
        self.min_samples_split = min_samples_split
        self.max_depth = max_depth
        self.criterion = criterion
        self.root = None

    def _calculate_entropy(self, y):
        """Return the Shannon entropy (in bits) of the labels in ``y``; 0 if empty."""
        if len(y) == 0:
            return 0.0
        _, counts = np.unique(y, return_counts=True)
        probabilities = counts / len(y)
        return float(-np.sum(probabilities * np.log2(probabilities)))

    def _calculate_information_gain(self, X, y, feature, threshold):
        """Return the entropy reduction from splitting on ``X[:, feature] <= threshold``."""
        left_mask = X[:, feature] <= threshold
        n = len(y)
        n_left = np.count_nonzero(left_mask)
        # ~left_mask (rather than "> threshold") keeps training consistent with
        # prediction, which sends every non-"<=" sample to the right.
        children_entropy = (
            (n_left / n) * self._calculate_entropy(y[left_mask])
            + ((n - n_left) / n) * self._calculate_entropy(y[~left_mask])
        )
        return self._calculate_entropy(y) - children_entropy

    def _split(self, X, y):
        """Choose the best binary split of ``(X, y)``.

        Returns ``(feature, threshold)``, or ``(None, None)`` when no split
        yields a positive information gain (e.g. identical feature rows with
        different labels).
        """
        candidates = []  # (feature, threshold, gain, split_info)
        for feature in range(X.shape[1]):
            column = X[:, feature]
            # Thresholds are observed values; the largest is skipped because it
            # would send every sample left and leave the right child empty.
            for threshold in np.unique(column)[:-1]:
                gain = self._calculate_information_gain(X, y, feature, threshold)
                if gain <= self._MIN_GAIN:
                    continue
                # Split information = entropy of the left/right partition itself;
                # it is > 0 here because both sides are non-empty.
                split_info = self._calculate_entropy(column <= threshold)
                candidates.append((feature, threshold, gain, split_info))

        if not candidates:
            return None, None

        if self.criterion == "gain":
            best = max(candidates, key=lambda c: c[2])
        else:
            # C4.5 heuristic: only splits with at least average gain compete on
            # gain ratio, which stops tiny, unbalanced splits from winning just
            # because their split information is small. The `or` fallback guards
            # against round-off when all gains are equal.
            mean_gain = sum(c[2] for c in candidates) / len(candidates)
            eligible = [c for c in candidates if c[2] >= mean_gain] or candidates
            best = max(eligible, key=lambda c: c[2] / c[3])
        # max() keeps the first of tied candidates, i.e. the lowest feature index
        # and threshold, matching a strict ">" scan.
        return best[0], best[1]

    def _build_tree(self, X, y, depth):
        """Recursively grow the subtree for samples ``(X, y)`` at ``depth``."""
        classes, counts = np.unique(y, return_counts=True)
        majority = classes[np.argmax(counts)]

        # Stop when the node is pure, too small, or at the depth limit.
        if (len(classes) == 1
                or len(y) < self.min_samples_split
                or depth >= self.max_depth):
            return Node(label=majority)

        feature, threshold = self._split(X, y)
        if feature is None:  # no informative split exists
            return Node(label=majority)

        left_mask = X[:, feature] <= threshold
        left_subtree = self._build_tree(X[left_mask], y[left_mask], depth + 1)
        right_subtree = self._build_tree(X[~left_mask], y[~left_mask], depth + 1)
        return Node(feature=feature, threshold=threshold,
                    left=left_subtree, right=right_subtree)

    def fit(self, X, y):
        """Build the tree from training data.

        ``X`` is array-like of shape (n_samples, n_features) with numeric
        features; ``y`` is array-like of n_samples class labels.

        Raises ValueError on empty input, non-2-D ``X`` or mismatched lengths.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        if X.ndim != 2:
            raise ValueError("X must be 2-D with shape (n_samples, n_features)")
        if len(y) == 0:
            raise ValueError("cannot fit on an empty dataset")
        if X.shape[0] != len(y):
            raise ValueError("X and y must contain the same number of samples")
        self.root = self._build_tree(X, y, depth=0)

    def _predict_single(self, x, node):
        """Walk from ``node`` down to a leaf for sample ``x`` and return its label."""
        while not node.is_leaf:
            node = node.left if x[node.feature] <= node.threshold else node.right
        return node.label

    def predict(self, X):
        """Return an array with the predicted class of each row of ``X``.

        Raises AttributeError if called before ``fit``.
        """
        if self.root is None:
            # AttributeError keeps the exception type callers saw previously
            # when `root` did not exist yet.
            raise AttributeError("model is not fitted yet; call fit() first")
        return np.array([self._predict_single(x, self.root) for x in np.asarray(X)])
```
这段代码实现了一个简单的 C4.5 决策树模型，包括节点类`Node`和决策树类`C45DecisionTree`。模型在数值特征上按阈值做二分裂（满足 `x[feature] <= threshold` 的样本进入左子树，其余进入右子树）。`fit`方法用于训练模型，`predict`方法用于预测新样本的类别。
阅读全文