Please provide code for the C4.5 algorithm and implement visualization
C4.5 is a decision-tree algorithm: it recursively splits the dataset on the feature with the highest information gain ratio, builds a decision tree, and uses that tree to classify new data. Below is a Python implementation of the C4.5 algorithm:
```python
import numpy as np
import math


class C45:
    def __init__(self, eps=1e-4):
        self.eps = eps      # minimum gain ratio required to keep splitting
        self.tree = {}

    def calc_entropy(self, y):
        # Shannon entropy of the label vector y
        entropy = 0
        for i in np.unique(y):
            p = len(y[y == i]) / len(y)
            entropy += -p * math.log(p, 2)
        return entropy

    def calc_cond_entropy(self, y, x):
        # Conditional entropy of y given the feature column x
        cond_entropy = 0
        for i in np.unique(x):
            p = len(x[x == i]) / len(x)
            cond_entropy += p * self.calc_entropy(y[x == i])
        return cond_entropy

    def calc_info_gain_ratio(self, y, x):
        # Information gain divided by the split information (the C4.5 criterion)
        info_gain = self.calc_entropy(y) - self.calc_cond_entropy(y, x)
        split_info = -sum((len(x[x == i]) / len(x)) * math.log(len(x[x == i]) / len(x), 2)
                          for i in np.unique(x))
        if split_info == 0:
            return 0
        return info_gain / split_info

    def fit(self, X, y):
        self.tree = self._fit(X, y, list(range(X.shape[1])))

    def _fit(self, X, y, features):
        # All samples share one class: return it as a leaf
        if len(np.unique(y)) == 1:
            return np.unique(y)[0]
        # No features left: return the majority class
        if len(features) == 0:
            return np.bincount(y).argmax()
        # Choose the feature with the largest information gain ratio
        max_gain_ratio = -1
        best_feature = None
        for feature in features:
            gain_ratio = self.calc_info_gain_ratio(y, X[:, feature])
            if gain_ratio > max_gain_ratio:
                max_gain_ratio = gain_ratio
                best_feature = feature
        # Stop if even the best split is not informative enough
        if max_gain_ratio < self.eps:
            return np.bincount(y).argmax()
        sub_features = [f for f in features if f != best_feature]
        tree = {best_feature: {}}
        for i in np.unique(X[:, best_feature]):
            mask = X[:, best_feature] == i
            tree[best_feature][i] = self._fit(X[mask], y[mask], sub_features)
        return tree

    def predict(self, X):
        return np.array([self._predict(x, self.tree) for x in X])

    def _predict(self, x, tree):
        # Internal nodes look like {feature_index: {feature_value: subtree}};
        # leaves are class labels
        if isinstance(tree, dict):
            feature = next(iter(tree))
            return self._predict(x, tree[feature][x[feature]])
        else:
            return tree


def plot_tree(tree, feature_names, class_names, filename):
    import pydotplus
    from io import StringIO
    from IPython.display import Image

    dot_data = StringIO()
    dot_data.write('digraph Tree {\n')
    dot_data.write('node [shape=box] ;\n')
    counter = [0]  # running id so every node in the DOT graph is unique

    def add_node(node):
        node_id = counter[0]
        counter[0] += 1
        if isinstance(node, dict):
            feature = next(iter(node))
            dot_data.write(f'{node_id} [label="{feature_names[feature]}"] ;\n')
            # One edge per feature value, labelled with that value
            for value, child in node[feature].items():
                child_id = add_node(child)
                dot_data.write(f'{node_id} -> {child_id} [label="{value}"] ;\n')
        else:
            dot_data.write(f'{node_id} [label="{class_names[node]}"] ;\n')
        return node_id

    add_node(tree)
    dot_data.write('}\n')
    graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
    graph.write_png(filename)
    return Image(graph.create_png())
```
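As a quick sanity check of the entropy and gain-ratio helpers, here is a tiny hand-made example (the column `x` and labels `y` are made up for illustration; the expected values follow directly from the definitions: entropy 1.0 for two balanced classes, gain ratio 1.0 for a perfectly separating feature):
```python
import numpy as np

clf = C45()
y = np.array([0, 0, 1, 1])
x = np.array([0, 0, 1, 1])   # a feature that perfectly separates the classes

print(clf.calc_entropy(y))             # 1.0  (two equally likely classes)
print(clf.calc_cond_entropy(y, x))     # 0.0  (each feature value is pure)
print(clf.calc_info_gain_ratio(y, x))  # 1.0  (gain 1.0 / split info 1.0)
```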
The code also includes a visualization function, `plot_tree`, which renders the learned decision tree as an image (it relies on the `pydotplus` package and a Graphviz installation), making it easier to understand how the tree was built. It is used as follows:
```python
model = C45()
model.fit(X_train, y_train)
plot_tree(model.tree, feature_names, class_names, 'tree.png')
```
Here `X_train` is the training feature matrix (a NumPy array of categorical feature values), `y_train` is the vector of training labels (non-negative integers, since the majority class is taken with `np.bincount`), `feature_names` is the list of feature names indexed by column, `class_names` is the list of class names indexed by label, and `tree.png` is the file in which the rendered tree is saved.
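For a self-contained run, here is a minimal end-to-end sketch on a hand-encoded toy weather dataset (the data, feature names, and class names are invented for illustration; writing the PNG again assumes `pydotplus` and Graphviz are available):
```python
import numpy as np

# Toy "play tennis" style data, integer-encoded:
# outlook: 0=sunny, 1=overcast, 2=rain; windy: 0=false, 1=true
X_train = np.array([
    [0, 1], [0, 0],   # sunny        -> no
    [1, 1], [1, 0],   # overcast     -> yes
    [2, 1], [2, 0],   # rain: depends on windy
])
y_train = np.array([0, 0, 1, 1, 0, 1])  # 0 = "no", 1 = "yes"

feature_names = ['outlook', 'windy']
class_names = ['no', 'yes']

model = C45()
model.fit(X_train, y_train)
print(model.tree)                 # nested dict: {feature: {value: subtree or leaf}}
print(model.predict(X_train))     # should reproduce y_train on this tiny dataset
plot_tree(model.tree, feature_names, class_names, 'tree.png')
```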