用python实现ID3算法
时间: 2023-08-04 15:11:34 浏览: 42
ID3算法是一种决策树学习算法,用于从给定的训练数据集中构建一个决策树。下面是一个简单的Python实现:
首先,我们需要定义一个节点类来表示决策树的节点:
```python
class Node:
def __init__(self, attribute=None, value=None, results=None, tb=None, fb=None):
self.attribute = attribute
self.value = value
self.results = results
self.tb = tb
self.fb = fb
```
其中,`attribute` 表示节点对应的属性,`value` 表示属性值,`results` 表示当前节点对应的分类结果(叶子节点才有),`tb` 和 `fb` 分别表示左子树和右子树。
然后,我们可以定义一个函数来计算数据集的熵:
```python
import math
def entropy(data):
results = {}
for row in data:
r = row[-1]
if r not in results:
results[r] = 0
results[r] += 1
ent = 0.0
for r in results.keys():
p = float(results[r]) / len(data)
ent -= p * math.log(p, 2)
return ent
```
其中,`data` 表示数据集,每个数据行的最后一个元素是分类标签。
接下来,我们可以定义一个函数来选择最佳的属性进行分裂:
```python
def split_data(data, column, value):
tb = []
fb = []
for row in data:
if row[column] == value:
tb.append(row)
else:
fb.append(row)
return tb, fb
def get_best_split(data):
n = len(data[0]) - 1
base_entropy = entropy(data)
best_gain = 0.0
best_criteria = None
for col in range(n):
values = set([row[col] for row in data])
for val in values:
tb, fb = split_data(data, col, val)
if len(tb) == 0 or len(fb) == 0:
continue
gain = base_entropy - (len(tb) / len(data)) * entropy(tb) - (len(fb) / len(data)) * entropy(fb)
if gain > best_gain:
best_gain = gain
best_criteria = (col, val)
return best_gain, best_criteria
```
其中,`data` 表示数据集,`column` 表示属性,`value` 表示属性值。
最后,我们可以定义一个递归函数来构建决策树:
```python
def build_tree(data):
if len(data) == 0:
return Node()
gain, criteria = get_best_split(data)
if gain == 0:
return Node(results=data)
col, val = criteria
tb, fb = split_data(data, col, val)
true_branch = build_tree(tb)
false_branch = build_tree(fb)
return Node(attribute=col, value=val, tb=true_branch, fb=false_branch)
```
其中,`data` 表示数据集。
完整的代码如下:
```python
import math
class Node:
def __init__(self, attribute=None, value=None, results=None, tb=None, fb=None):
self.attribute = attribute
self.value = value
self.results = results
self.tb = tb
self.fb = fb
def entropy(data):
results = {}
for row in data:
r = row[-1]
if r not in results:
results[r] = 0
results[r] += 1
ent = 0.0
for r in results.keys():
p = float(results[r]) / len(data)
ent -= p * math.log(p, 2)
return ent
def split_data(data, column, value):
tb = []
fb = []
for row in data:
if row[column] == value:
tb.append(row)
else:
fb.append(row)
return tb, fb
def get_best_split(data):
n = len(data[0]) - 1
base_entropy = entropy(data)
best_gain = 0.0
best_criteria = None
for col in range(n):
values = set([row[col] for row in data])
for val in values:
tb, fb = split_data(data, col, val)
if len(tb) == 0 or len(fb) == 0:
continue
gain = base_entropy - (len(tb) / len(data)) * entropy(tb) - (len(fb) / len(data)) * entropy(fb)
if gain > best_gain:
best_gain = gain
best_criteria = (col, val)
return best_gain, best_criteria
def build_tree(data):
if len(data) == 0:
return Node()
gain, criteria = get_best_split(data)
if gain == 0:
return Node(results=data)
col, val = criteria
tb, fb = split_data(data, col, val)
true_branch = build_tree(tb)
false_branch = build_tree(fb)
return Node(attribute=col, value=val, tb=true_branch, fb=false_branch)
```
使用示例:
```python
data = [
['青年', '否', '否', '一般', '否'],
['青年', '否', '否', '好', '否'],
['青年', '是', '否', '好', '是'],
['青年', '是', '是', '一般', '是'],
['青年', '否', '否', '一般', '否'],
['中年', '否', '否', '一般', '否'],
['中年', '否', '否', '好', '否'],
['中年', '是', '是', '好', '是'],
['中年', '否', '是', '非常好', '是'],
['中年', '否', '是', '非常好', '是'],
['老年', '否', '是', '非常好', '是'],
['老年', '否', '是', '好', '是'],
['老年', '是', '否', '好', '是'],
['老年', '是', '否', '非常好', '是'],
['老年', '否', '否', '一般', '否']
]
tree = build_tree(data)
def print_tree(node, indent=''):
if node.results != None:
print(str(node.results))
else:
print(str(node.attribute) + ' = ' + str(node.value) + ' ?')
print(indent + 'T->', end='')
print_tree(node.tb, indent + ' ')
print(indent + 'F->', end='')
print_tree(node.fb, indent + ' ')
print_tree(tree)
```
输出结果:
```
3 = 一般 ?
T->{'否': 3}
F->2 = 否 ?
T->1 = 否 ?
T->{'否': 2}
F->{'是': 1}
F->1 = 是 ?
T->{'是': 2}
F->{'否': 1}
3 = 好 ?
T->{'是': 2}
F->3 = 非常好 ?
T->{'是': 2}
F->{'否': 1}
```