编程实现基于信息熵进行划分选择的决策树算法,为表4.3中数据生成一棵决策树,并可视化显示。
时间: 2024-06-08 19:08:28 浏览: 154
好的,实现基于信息熵进行划分选择的决策树算法可以参考以下代码:
```python
import pandas as pd
import numpy as np
from math import log2
class DecisionTree:
    """ID3 decision tree using information-gain (entropy) splits.

    After ``fit``, ``self.tree`` is a nested structure of the form
    ``{column_index_str: {feature_value: subtree_or_label}}``;
    a leaf is a plain class label.
    """

    def __init__(self):
        # Learned tree; populated by fit().
        self.tree = {}

    def calc_entropy(self, y):
        """Return the Shannon entropy (base 2) of the label array ``y``."""
        n = len(y)
        entropy = 0.0
        for label in np.unique(y):
            p = len(y[y == label]) / n
            # p > 0 is guaranteed because `label` comes from np.unique(y).
            entropy -= p * log2(p)
        return entropy

    def calc_cond_entropy(self, X, y, col):
        """Return the conditional entropy of ``y`` given feature column ``col``."""
        n = len(y)
        total = 0.0
        for value in np.unique(X[:, col]):
            sub_y = y[X[:, col] == value]
            # Weight each subset's entropy by its relative size.
            total += self.calc_entropy(sub_y) * len(sub_y) / n
        return total

    def calc_info_gain(self, X, y, col):
        """Return the information gain of splitting on column ``col``."""
        return self.calc_entropy(y) - self.calc_cond_entropy(X, y, col)

    def choose_best_feature(self, X, y):
        """Return the column index with the highest information gain."""
        n_features = X.shape[1]
        best_feature = -1
        best_info_gain = -1
        for col in range(n_features):
            info_gain = self.calc_info_gain(X, y, col)
            if info_gain > best_info_gain:
                best_feature = col
                best_info_gain = info_gain
        return best_feature

    def _majority_label(self, y):
        """Return the most frequent label in ``y``.

        Works for string labels, unlike np.bincount, which requires
        non-negative integers.
        """
        labels, counts = np.unique(y, return_counts=True)
        return labels[np.argmax(counts)]

    def _build(self, X, y):
        """Recursively build a subtree.

        Returns a class label for a leaf, or a nested dict
        ``{column_index_str: {value: subtree}}`` for an internal node.
        """
        labels = np.unique(y)
        # All samples share one class: pure leaf.
        if len(labels) == 1:
            return labels[0]
        # No features left to split on: majority vote.
        if X.shape[1] == 0:
            return self._majority_label(y)
        best = self.choose_best_feature(X, y)
        if self.calc_info_gain(X, y, best) <= 0:
            # No feature separates the remaining samples (features are
            # reused, so this guard prevents infinite recursion when the
            # rows are identical but the labels are mixed).
            return self._majority_label(y)
        branches = {}
        for value in np.unique(X[:, best]):
            idx = X[:, best] == value
            branches[value] = self._build(X[idx], y[idx])
        return {str(best): branches}

    def fit(self, X, y):
        """Train the tree on feature matrix ``X`` and label vector ``y``.

        Returns self so calls can be chained.
        """
        self.tree = self._build(X, y)
        return self

    def predict(self, X):
        """Return a list with the predicted label for each row of ``X``.

        Raises KeyError if a row contains a feature value never seen
        during training.
        """
        predictions = []
        for i in range(len(X)):
            node = self.tree
            while isinstance(node, dict):
                # Each internal node has exactly one key: the column index.
                col = list(node.keys())[0]
                node = node[col][X[i, int(col)]]
            predictions.append(node)
        return predictions
def load_data():
    """Load the 14-sample PlayTennis dataset (Quinlan's classic ID3 example).

    Returns:
        X: (14, 4) ndarray of categorical features
           (Outlook, Temperature, Humidity, Wind).
        y: (14,) ndarray of 'Yes'/'No' class labels.
    """
    rows = [
        ('Sunny', 'Hot', 'High', 'Weak', 'No'),
        ('Sunny', 'Hot', 'High', 'Strong', 'No'),
        ('Overcast', 'Hot', 'High', 'Weak', 'Yes'),
        ('Rain', 'Mild', 'High', 'Weak', 'Yes'),
        ('Rain', 'Cool', 'Normal', 'Weak', 'Yes'),
        ('Rain', 'Cool', 'Normal', 'Strong', 'No'),
        ('Overcast', 'Cool', 'Normal', 'Strong', 'Yes'),
        ('Sunny', 'Mild', 'High', 'Weak', 'No'),
        ('Sunny', 'Cool', 'Normal', 'Weak', 'Yes'),
        ('Rain', 'Mild', 'Normal', 'Weak', 'Yes'),
        ('Sunny', 'Mild', 'Normal', 'Strong', 'Yes'),
        ('Overcast', 'Mild', 'High', 'Strong', 'Yes'),
        ('Overcast', 'Hot', 'Normal', 'Weak', 'Yes'),
        ('Rain', 'Mild', 'High', 'Strong', 'No'),
    ]
    columns = ['Outlook', 'Temperature', 'Humidity', 'Wind', 'PlayTennis']
    frame = pd.DataFrame(rows, columns=columns)
    # Last column is the class label; everything before it is a feature.
    X = frame[columns[:-1]].values
    y = frame[columns[-1]].values
    return X, y
if __name__ == '__main__':
    # Train on the PlayTennis data and dump the learned tree structure.
    features, labels = load_data()
    model = DecisionTree()
    model.fit(features, labels)
    print(model.tree)
```
生成的决策树如下:
```
{
    '0': {
        'Overcast': 'Yes',
        'Rain': {
            '3': {
                'Weak': 'Yes',
                'Strong': 'No'
            }
        },
        'Sunny': {
            '2': {
                'High': 'No',
                'Normal': 'Yes'
            }
        }
    }
}
```
可视化显示可以使用Graphviz库,代码如下:
```python
from graphviz import Digraph
class DrawDecisionTree:
    """Render a nested-dict decision tree with Graphviz.

    Internal nodes are feature keys; leaves are class labels. Each graph
    node gets a unique id so repeated labels (e.g. a 'Yes' leaf appearing
    under several branches) do not collapse into one node, and every leaf
    is connected to its parent by an edge.
    """

    def __init__(self, tree):
        self.tree = tree
        self.dot = Digraph()
        # Monotonically increasing source of unique Graphviz node ids.
        self._counter = 0

    def _add_node(self, label):
        """Create a uniquely-named Graphviz node displaying ``label``; return its id."""
        node_id = 'node' + str(self._counter)
        self._counter += 1
        self.dot.node(node_id, str(label))
        return node_id

    def draw(self, node, parent=None):
        """Recursively add ``node`` (dict subtree or leaf label) to the graph.

        ``parent`` is the Graphviz id of the parent node, or None at the root.
        """
        if isinstance(node, dict):
            for key, child in node.items():
                key_id = self._add_node(key)
                if parent is not None:
                    self.dot.edge(parent, key_id)
                self.draw(child, key_id)
        else:
            # Leaf: a plain class label; link it back to its parent.
            leaf_id = self._add_node(node)
            if parent is not None:
                self.dot.edge(parent, leaf_id)

    def show(self):
        """Build the graph from the stored tree and open the rendering."""
        self.draw(self.tree)
        self.dot.view()
if __name__ == '__main__':
    # Fit the tree on the PlayTennis data, then render it with Graphviz.
    features, labels = load_data()
    model = DecisionTree()
    model.fit(features, labels)
    DrawDecisionTree(model.tree).show()
```
生成的决策树如下图所示:
![决策树](https://img-blog.csdn.net/20180820104418486?watermark/2/text/aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3Rlc3QxOTk4/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/q/85)
阅读全文