用Python写一个西瓜书id3决策树模型
时间: 2024-03-14 09:44:47 浏览: 104
好的,以下是用Python实现的西瓜书id3决策树模型的代码:
```python
import math
import pandas as pd
class Node:
'''
决策树节点类
'''
def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
self.col = col # 待检验的判断条件所对应的列索引值
self.value = value # 为了使结果为True,当前列必须匹配的值
self.results = results # 存储叶节点上的结果,是一个字典形式,键为类别,值为次数
self.tb = tb # 左子树
self.fb = fb # 右子树
def load_data():
'''
加载西瓜数据集,返回特征数据和标签
'''
data = pd.read_csv('watermelon.csv')
return data.iloc[:, 1:-1], data.iloc[:, -1]
def calc_entropy(labels):
'''
计算数据集的熵
'''
total = len(labels)
counts = {}
for label in labels:
if label not in counts:
counts[label] = 0
counts[label] += 1
entropy = 0.0
for key in counts:
p = counts[key] / total
entropy -= p * math.log2(p)
return entropy
def split_data(data, labels, col, value):
'''
根据给定特征划分数据集
'''
tb_rows, fb_rows = [], []
for i in range(len(data)):
row = list(data.iloc[i])
if row[col] == value:
tb_rows.append(row + [labels[i]])
else:
fb_rows.append(row + [labels[i]])
return pd.DataFrame(tb_rows, columns=data.columns.tolist() + ['label']), pd.DataFrame(fb_rows, columns=data.columns.tolist() + ['label'])
def build_tree(data, labels):
'''
构建决策树
'''
if len(labels) == 0:
return Node()
current_entropy = calc_entropy(labels)
best_gain = 0.0
best_criteria = None
best_sets = None
feature_num = len(data.columns)
for col in range(feature_num):
column_values = set(data.iloc[:, col])
for value in column_values:
tb_data, fb_data = split_data(data, labels, col, value)
p = len(tb_data) / len(data)
gain = current_entropy - p * calc_entropy(tb_data['label']) - (1 - p) * calc_entropy(fb_data['label'])
if gain > best_gain and len(tb_data) > 0 and len(fb_data) > 0:
best_gain = gain
best_criteria = (col, value)
best_sets = (tb_data, fb_data)
if best_gain > 0:
tb = build_tree(best_sets[0], best_sets[0]['label'])
fb = build_tree(best_sets[1], best_sets[1]['label'])
return Node(col=best_criteria[0], value=best_criteria[1], tb=tb, fb=fb)
else:
return Node(results={label: len([label for label in labels if label == '是']), '否': len([label for label in labels if label == '否'])})
def classify(sample, tree):
'''
使用决策树对单个样本进行分类
'''
if tree.results is not None:
return tree.results
else:
v = sample[tree.col]
branch = None
if v == tree.value:
branch = tree.tb
else:
branch = tree.fb
return classify(sample, branch)
def predict(data, tree):
'''
对数据集进行分类
'''
return [classify(list(data.iloc[i]), tree) for i in range(len(data))]
data, labels = load_data()
tree = build_tree(data, labels)
print(predict(data, tree))
```
这里使用了pandas库来读取数据集,需要将数据集放在与代码文件相同的目录下,并命名为watermelon.csv。函数load_data返回的是特征数据和标签,分别是DataFrame类型和Series类型。函数build_tree实现了id3算法,返回构建好的决策树。函数classify用于对单个样本进行分类,函数predict用于对整个数据集进行分类。
阅读全文