C4.5决策树算法的Python实现
时间: 2023-07-20 21:10:44 浏览: 53
下面是一个简单的C4.5决策树算法的Python实现，仅供参考：
```python
import math
import pandas as pd
class C45DecisionTree:
    """A simple C4.5-style decision tree classifier for categorical features.

    Splits are chosen by information gain ratio. Every distinct value of the
    chosen feature becomes one branch (multi-way split). Continuous features
    are not discretized, there is no pruning, and rows whose split value is
    missing (NaN) are simply not routed to any branch.

    Tree representation (stored in ``self.decision_tree`` after ``fit``):
      * leaf:          ``{'label': <class>}``
      * internal node: ``{'feature': <column>, 'majority': <class>,
                          'children': {<feature value>: <subtree>, ...}}``
    Branches live under ``'children'`` so that a feature value such as
    ``'label'`` or ``'feature'`` cannot collide with the node's own keys.
    ``'majority'`` is the fallback prediction for values not seen in training.

    Args:
        epsilon: Minimum information gain ratio required to split a node;
            below it the node becomes a leaf with the majority class.
    """

    def __init__(self, epsilon=0.1):
        self.epsilon = epsilon

    def fit(self, X, y):
        """Build the decision tree from features ``X`` and labels ``y``.

        Args:
            X: DataFrame of categorical features, one column per feature.
            y: Class labels. A Series is used as-is; its index must match
                ``X.index`` because rows are selected with boolean masks built
                from ``X``. Any other sequence is wrapped positionally onto
                ``X.index``.

        Returns:
            self, so that ``fit(...).predict(...)`` can be chained.

        Raises:
            ValueError: if there are no training samples, or if a non-Series
                ``y`` has a different length than ``X``.
        """
        if not isinstance(y, pd.Series):
            y = pd.Series(list(y), index=X.index)
        if len(y) == 0:
            raise ValueError("cannot fit a decision tree on an empty training set")
        self.decision_tree = self._build_tree(X, y)
        return self

    def predict(self, X):
        """Return a list with one predicted label per row of DataFrame ``X``."""
        return [self._predict_one(row, self.decision_tree) for _, row in X.iterrows()]

    def _build_tree(self, X, y):
        """Recursively build the subtree for samples ``X`` with labels ``y``.

        ``y`` is non-empty: ``fit`` rejects empty input, and each branch is
        built only for a feature value that occurs in at least one row.
        """
        # All samples belong to one class: return it as a leaf.
        if y.nunique(dropna=False) == 1:
            # iloc, not y[0]: after subsetting, y's index need not contain label 0.
            return {'label': y.iloc[0]}
        majority = y.value_counts().idxmax()
        # No features left to split on: majority-class leaf.
        if len(X.columns) == 0:
            return {'label': majority}
        # Pick the feature with the highest information gain ratio
        # (ties resolved in column order, as max() keeps the first maximum).
        best_feature, best_gain = max(
            ((col, self._information_gain_ratio(X[col], y)) for col in X.columns),
            key=lambda item: item[1],
        )
        # Gain ratio below the threshold: stop splitting.
        if best_gain < self.epsilon:
            return {'label': majority}
        children = {}
        # dropna(): a NaN key would never match any row (NaN != NaN) and would
        # produce an empty branch.
        for value in X[best_feature].dropna().unique():
            mask = X[best_feature] == value
            children[value] = self._build_tree(X[mask].drop(columns=best_feature), y[mask])
        return {'feature': best_feature, 'majority': majority, 'children': children}

    def _predict_one(self, row, tree):
        """Walk ``tree`` for a single sample ``row`` (a Series keyed by feature)."""
        node = tree
        while 'feature' in node:
            child = node['children'].get(row[node['feature']])
            if child is None:
                # Value never seen at this node during training (or missing):
                # fall back to the node's majority class instead of raising.
                return node['majority']
            node = child
        return node['label']

    def _entropy(self, y):
        """Shannon entropy, in bits, of the value distribution of Series ``y``."""
        total = len(y)
        # Skip zero counts (reported for unused categories of a categorical
        # dtype); log2(0) would raise, and 0 * log 0 is taken as 0.
        return -sum((c / total) * math.log2(c / total) for c in y.value_counts() if c > 0)

    def _information_gain_ratio(self, X, y):
        """Information gain ratio of splitting labels ``y`` by feature column ``X``.

        Returns 0 when the split information is 0 (``X`` takes a single value),
        since such a split is useless and the ratio would divide by zero.
        """
        info_gain = self._entropy(y) - self._conditional_entropy(X, y)
        # Split information is the entropy of the feature's own value distribution.
        split_info = self._entropy(X)
        if split_info == 0:
            return 0
        return info_gain / split_info

    def _conditional_entropy(self, X, y):
        """Expected entropy of ``y`` after partitioning the rows by the values of ``X``."""
        total = len(X)
        # Iterate value_counts() rather than set(X): value_counts drops NaN,
        # so every value here has a matching count (no KeyError on NaN).
        return sum(
            (count / total) * self._entropy(y[X == value])
            for value, count in X.value_counts().items()
            if count > 0
        )
```
这个实现使用Pandas库来处理数据。其中，`X`是一个`DataFrame`对象，表示特征矩阵；`y`是一个`Series`对象，表示标签向量（其索引应与`X`的索引一致）。`fit`方法用于训练模型，`predict`方法用于预测新的样本。`epsilon`是信息增益比的阈值：当最佳特征的信息增益比小于该值时，节点不再分裂，而是以样本中出现最多的类别作为叶子节点。在内部，`_build_tree`方法用于递归地构建决策树，`_predict_one`方法用于递归地预测单个样本，`_entropy`方法用于计算熵，`_information_gain_ratio`方法用于计算信息增益比，`_conditional_entropy`方法用于计算条件熵。注意：该实现只处理离散（类别型）特征，没有实现C4.5中的连续特征离散化、缺失值处理和剪枝。