用 Python 写一个带注释的决策树代码
时间: 2024-05-07 18:22:35 浏览: 76
决策树算法Python代码实现,ipynb格式
5星 · 资源好评率100%
以下是用 Python 实现决策树的代码,附有注释:
```python
from math import log2
from collections import Counter
def calc_entropy(data):
    """Return the Shannon entropy (in bits) of the class labels in *data*.

    Args:
        data: Sequence of rows; the last element of each row is read as
            its class label.

    Returns:
        The entropy as a float. An empty dataset, or one whose rows all
        share a single label, yields ``0.0``.
    """
    n = len(data)
    label_counts = Counter(row[-1] for row in data)
    entropy = 0.0
    # Only the counts matter here, so iterate over them directly.
    for count in label_counts.values():
        prob = count / n
        entropy -= prob * log2(prob)
    return entropy
def split_data(data, feature, value):
    """Return the rows whose column *feature* equals *value*, minus that column.

    Args:
        data: Sequence of rows (lists or tuples).
        feature: Zero-based index of the column to filter on and remove.
        value: The value a row must hold in that column to be kept.

    Returns:
        A new list of new rows; the input rows are not modified. Each
        returned row has the same sequence type as the row it came from.
    """
    # Slice concatenation builds a fresh row and, unlike ``extend``, also
    # works when the rows are tuples.
    return [
        row[:feature] + row[feature + 1:]
        for row in data
        if row[feature] == value
    ]
def choose_best_feature(data):
    """Return the index of the feature with the highest information gain.

    Args:
        data: Non-empty sequence of rows; every column except the last is
            treated as a feature and the last as the class label.

    Returns:
        The zero-based index of the feature whose split gives the largest
        strictly positive information gain, or ``-1`` when no feature
        yields a positive gain (including when there are no features).

    Raises:
        IndexError: If *data* is empty.
    """
    num_features = len(data[0]) - 1
    total = len(data)
    base_entropy = calc_entropy(data)
    best_info_gain = 0.0
    best_feature = -1
    for i in range(num_features):
        unique_vals = {row[i] for row in data}
        # Weighted average of the entropies of the subsets produced by
        # splitting on feature i.
        new_entropy = 0.0
        for value in unique_vals:
            sub_data = split_data(data, i, value)
            prob = len(sub_data) / total
            new_entropy += prob * calc_entropy(sub_data)
        info_gain = base_entropy - new_entropy
        # Strict comparison: ties keep the earlier feature, and a feature
        # with zero gain is never selected.
        if info_gain > best_info_gain:
            best_info_gain = info_gain
            best_feature = i
    return best_feature
def majority_class(class_list):
    """Return the most frequent label in *class_list*.

    Args:
        class_list: Non-empty iterable of hashable class labels.

    Returns:
        The label with the highest count. When several labels tie, the
        one that appears first in *class_list* is returned.

    Raises:
        IndexError: If *class_list* is empty.
    """
    return Counter(class_list).most_common(1)[0][0]
def create_decision_tree(data, labels):
    """Recursively build a decision tree as nested dicts.

    Args:
        data: Non-empty sequence of rows; the last element of each row is
            its class label and the preceding elements are feature values.
        labels: Feature names, one per feature column, in column order.
            This sequence is not modified.

    Returns:
        Either a class label (leaf) or a dict of the form
        ``{feature_name: {feature_value: subtree, ...}}``.

    Raises:
        IndexError: If *data* is empty.
    """
    class_list = [row[-1] for row in data]
    # All rows share one class: this branch is a pure leaf.
    if class_list.count(class_list[0]) == len(class_list):
        return class_list[0]
    # Only the label column is left, so fall back to a majority vote.
    if len(data[0]) == 1:
        return majority_class(class_list)
    best_feat = choose_best_feature(data)
    # -1 means no feature gives a positive information gain. Using it as
    # an index would split on the class column itself, so stop here with
    # a majority vote instead.
    if best_feat == -1:
        return majority_class(class_list)
    best_feat_label = labels[best_feat]
    # Build the remaining labels as a new sequence rather than deleting
    # from the caller's list.
    sub_labels = labels[:best_feat] + labels[best_feat + 1:]
    unique_vals = {row[best_feat] for row in data}
    branches = {}
    for value in unique_vals:
        branches[value] = create_decision_tree(
            split_data(data, best_feat, value), sub_labels)
    return {best_feat_label: branches}
```
使用示例:
```python
# Toy dataset: each row holds two feature values followed by a class label.
data = [
    [1, 1, 'yes'],
    [1, 1, 'yes'],
    [1, 0, 'no'],
    [0, 1, 'no'],
    [0, 1, 'no'],
]
# Feature names, in the same order as the feature columns of each row.
labels = ['no surfacing', 'flippers']

# Build the tree and print it as a nested dict.
tree = create_decision_tree(data, labels)
print(tree)
```
输出:
```
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
```
阅读全文