Python实现ID3决策树代码
时间: 2023-07-16 14:15:37 浏览: 113
以下是一个用Python实现的简单ID3决策树算法代码:
```python
import math
from collections import Counter

import pandas as pd
class Node:
    """A single node of the decision tree built by ``id3``.

    Every field starts out empty; ``id3`` fills in whichever ones apply:

    * ``children``  -- maps a value of the split attribute to the child
      node reached through that value.
    * ``attribute`` -- name of the attribute this node splits on.
    * ``value``     -- value of the parent's split attribute on the branch
      that leads to this node.
    * ``label``     -- label stored on a leaf node.
    """

    def __init__(self):
        # A fresh dict per instance so nodes never share their children.
        self.children = dict()
        # All three text fields default to the empty string.
        self.attribute = self.value = self.label = ""
def entropy(data):
    """Return the Shannon entropy (in bits) of the values in *data*.

    Args:
        data: Any iterable of hashable values (list, tuple, pandas Series,
            dict view, ...). Values are read by iteration only, so the
            container does not have to support positional indexing.

    Returns:
        ``0`` for an empty input, otherwise ``-sum(p * log2(p))`` over the
        relative frequency ``p`` of each distinct value.
    """
    # One pass to tally every distinct value instead of rescanning the set
    # of classes for each element.
    counts = Counter(data)
    n = sum(counts.values())
    if n == 0:
        return 0
    # Every tallied count is >= 1, so log2 never sees a zero probability.
    return -sum((c / n) * math.log2(c / n) for c in counts.values())
def information_gain(data, attribute, label):
    """Return the information gain of splitting *data* on *attribute*.

    Args:
        data: Column-indexable table, e.g. a pandas DataFrame (or a dict of
            equal-length sequences), where ``data[name]`` yields one column.
        attribute: Name of the column to split on.
        label: Name of the target column.

    Returns:
        Entropy of the label column minus the size-weighted entropy of the
        label subsets obtained by grouping rows on the attribute's value.
    """
    labels = data[label]
    n = len(labels)

    # Walk the two columns in lockstep; indexing ``data[i]`` would look up
    # a *column* named ``i`` on a DataFrame rather than row ``i``.
    partitions = {}
    for attr_value, target in zip(data[attribute], labels):
        partitions.setdefault(attr_value, []).append(target)

    # With no rows there are no partitions, so no division by ``n`` occurs.
    weighted_entropy = sum(
        entropy(part) * len(part) / n for part in partitions.values()
    )
    return entropy(labels) - weighted_entropy
def id3(data, attributes, label):
    """Recursively build an ID3 decision tree.

    Args:
        data: pandas DataFrame holding the training rows; expected to
            contain at least one row.
        attributes: List of column names still available for splitting.
        label: Name of the target column.

    Returns:
        The root ``Node`` of the (sub)tree. Leaves carry ``label``; inner
        nodes carry ``attribute`` and one entry in ``children`` per value
        the attribute takes in *data*. Each child's ``value`` is the
        attribute value on the branch leading to it.
    """
    root = Node()
    labels = data[label]

    # Case 1: if all examples have the same label
    distinct_labels = set(labels)
    if len(distinct_labels) == 1:
        root.label = next(iter(distinct_labels))
        return root

    # Case 2: if attributes is empty -> fall back to the majority label.
    # (``Series.count`` is not ``list.count``, so tally with Counter.)
    if not attributes:
        root.label = Counter(labels).most_common(1)[0][0]
        return root

    # Find the best attribute to split on
    best_attr = max(attributes, key=lambda a: information_gain(data, a, label))
    root.attribute = best_attr

    # Group row *positions* by the value of the best attribute. Iterating
    # the column avoids ``data[i]``, which is a column lookup on a
    # DataFrame and raises KeyError.
    positions_by_value = {}
    for pos, attr_value in enumerate(data[best_attr]):
        positions_by_value.setdefault(attr_value, []).append(pos)

    # Create a branch for each value of the best attribute
    remaining = [a for a in attributes if a != best_attr]
    for attr_value, positions in positions_by_value.items():
        # ``iloc`` selects by position, so this also works when the
        # caller's frame does not have a default RangeIndex.
        subset = data.iloc[positions].reset_index(drop=True)
        child = id3(subset, remaining, label)
        child.value = attr_value
        root.children[attr_value] = child
    return root
# Example usage
# Fourteen weather observations: "Play" is the target column and every
# other column is a candidate attribute to split on.
data = pd.DataFrame(
    {
        "Outlook": [
            "Sunny", "Sunny", "Overcast", "Rainy", "Rainy", "Rainy", "Overcast",
            "Sunny", "Sunny", "Rainy", "Sunny", "Overcast", "Overcast", "Rainy",
        ],
        "Temperature": [
            "Hot", "Hot", "Hot", "Mild", "Cool", "Cool", "Cool",
            "Mild", "Cool", "Mild", "Mild", "Mild", "Hot", "Mild",
        ],
        "Humidity": [
            "High", "High", "High", "High", "Normal", "Normal", "Normal",
            "High", "Normal", "Normal", "Normal", "High", "Normal", "High",
        ],
        "Windy": [
            False, True, False, False, False, True, True,
            False, False, False, True, True, False, True,
        ],
        "Play": [
            False, False, True, True, True, False, True,
            False, True, True, True, True, True, False,
        ],
    }
)
# Split candidates are all columns except the target, in column order.
root = id3(data, [name for name in data.columns if name != "Play"], "Play")
```
阅读全文