Python code for the ID3 decision tree algorithm
Date: 2023-10-11 21:13:23
Below is a simple example implementation of the ID3 decision tree algorithm in Python:
```python
import math
from collections import Counter


def entropy(labels):
    """Return the Shannon entropy (in bits) of a sequence of class labels."""
    n = len(labels)
    label_counts = Counter(labels)
    probs = [count / n for count in label_counts.values()]
    return -sum(p * math.log2(p) for p in probs)


def information_gain(data, split_attr, target_attr):
    """Return the information gain from splitting `data` on `split_attr`."""
    original_entropy = entropy(data[target_attr])
    n = len(data)
    split_counts = Counter(data[split_attr])
    # Weighted average entropy of the subsets produced by the split.
    split_entropy = sum(
        count / n * entropy(data[data[split_attr] == split_val][target_attr])
        for split_val, count in split_counts.items()
    )
    return original_entropy - split_entropy


def id3(data, target_attr, attrs):
    """Build an ID3 decision tree as a nested dict; leaves are class labels."""
    # All rows share one label: return it as a leaf.
    if len(set(data[target_attr])) == 1:
        return data[target_attr].iloc[0]
    # No attributes left to split on: return the majority label.
    if not attrs:
        return Counter(data[target_attr]).most_common(1)[0][0]
    # Split on the attribute with the highest information gain.
    best_attr = max(attrs, key=lambda attr: information_gain(data, attr, target_attr))
    tree = {best_attr: {}}
    for attr_val in set(data[best_attr]):
        subset = data[data[best_attr] == attr_val].drop(best_attr, axis=1)
        tree[best_attr][attr_val] = id3(subset, target_attr, attrs - {best_attr})
    return tree
```
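Here, `data` is a pandas DataFrame, `target_attr` is the name of the target column, and `attrs` is a set of the attribute names available for splitting (it must be a `set`, and it should not include the target column). The function `entropy` computes the entropy of a collection of labels, `information_gain` computes the information gain of splitting on one attribute, and `id3` is the main recursive function. It returns a nested dictionary in which each key is an attribute name and the corresponding value maps each value of that attribute to a subtree; a leaf is simply a class label.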
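The nested-dictionary format described above can be used for prediction by walking from the root down to a leaf. The following is a minimal sketch, not part of the original code: the function name `predict`, its `default` parameter, and the hand-written `toy_tree` (with made-up attribute names and labels) are all assumptions for illustration.

```python
def predict(tree, sample, default=None):
    """Classify `sample` (a dict of attribute -> value) with a nested-dict tree.

    `predict` and `default` are hypothetical names, not part of the original code.
    """
    # Internal nodes are dicts of the form {attr: {value: subtree, ...}};
    # anything that is not a dict is treated as a leaf (a class label).
    while isinstance(tree, dict):
        attr = next(iter(tree))      # the single attribute tested at this node
        branches = tree[attr]
        value = sample.get(attr)
        if value not in branches:    # value never seen for this attribute in training
            return default
        tree = branches[value]
    return tree


# Hand-written tree in the same shape that id3() returns (illustrative data only).
toy_tree = {
    "outlook": {
        "sunny": {"humidity": {"high": "no", "normal": "yes"}},
        "overcast": "yes",
        "rain": {"wind": {"strong": "no", "weak": "yes"}},
    }
}

print(predict(toy_tree, {"outlook": "sunny", "humidity": "normal"}))
print(predict(toy_tree, {"outlook": "overcast"}))
```

Returning `default` for an unseen attribute value is one possible design choice; another reasonable option would be to fall back to the majority label of the training subset at that node, which would require storing that label in the tree.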