用Python实现ID3算法
时间: 2023-09-17 21:13:39 浏览: 35
好的，下面是一个用Python实现ID3算法的简单示例：
```
import pandas as pd
import numpy as np
import math
class Node:
    """A node of an ID3 decision tree.

    Attributes:
        feature: Column index this node splits on, or None for a leaf.
        label: Class label predicted at this node, or None.
        child: Container of child nodes. When not supplied, every node gets
            its own fresh empty list, so callers can ``append`` children
            without first checking for None.
    """

    def __init__(self, feature=None, label=None, child=None):
        self.feature = feature
        self.label = label
        # A new list per instance: not None (so ``append`` works) and not a
        # shared mutable default argument.
        self.child = [] if child is None else child


def entropy(data):
    """Return the Shannon entropy, in bits, of the class labels in *data*.

    The class label is read from the last column of the 2-D array *data*.
    An empty dataset has no uncertainty, so 0.0 is returned. Normalizing
    zero counts would otherwise produce NaN and a division warning.
    """
    target = data[:, -1]
    if target.size == 0:
        return 0.0
    _, counts = np.unique(target, return_counts=True)
    p = counts / counts.sum()
    # np.unique only reports values that occur, so every p > 0 and log2 is finite.
    return float(-np.sum(p * np.log2(p)))


def split_data(data, feature_index):
    """Partition the rows of *data* by the distinct values of one column.

    Returns a list of row subsets, one per distinct value in column
    *feature_index*, in the sorted order that ``np.unique`` yields them.
    """
    column = data[:, feature_index]
    return [data[column == v] for v in np.unique(column)]


def info_gain(data, feature_index):
    """Return the ID3 information gain of splitting *data* on one column.

    gain = H(D) - sum_v (|D_v| / |D|) * H(D_v), where each D_v is the
    subset of rows sharing value v in column *feature_index*.
    """
    total = data.shape[0]
    conditional = sum(
        (subset.shape[0] / total) * entropy(subset)
        for subset in split_data(data, feature_index)
    )
    return entropy(data) - conditional


def majority_label(data):
    """Return the most frequent class label (last column) in *data*.

    On a tie, the label that sorts first in ``np.unique``'s output wins,
    because ``argmax`` picks the first maximal count.
    """
    labels, counts = np.unique(data[:, -1], return_counts=True)
    return labels[np.argmax(counts)]


def build_tree(data, features):
    """Recursively grow an ID3 decision tree.

    Args:
        data: 2-D array whose last column holds the class label and whose
            other columns hold discrete feature values.
        features: Column indices still available for splitting.

    Returns:
        The root ``Node``.
        - A leaf has no children and carries the predicted ``label``.
        - An internal node splits on column ``feature``. Its ``child`` is a
          dict mapping each value seen in that column during training to
          the subtree for that value.
        - An internal node's ``label`` is the majority label of the rows
          that reached it. ``predict`` uses it as a fallback for unseen
          values.

    Raises:
        ValueError: if *data* has no rows.
    """
    if data.shape[0] == 0:
        raise ValueError("cannot build a decision tree from an empty dataset")
    labels = data[:, -1]
    # All rows share one label: pure leaf.
    if len(np.unique(labels)) == 1:
        return Node(label=labels[0])
    # No features left to split on: fall back to the majority label.
    if len(features) == 0:
        return Node(label=majority_label(data))
    gains = [info_gain(data, f) for f in features]
    best = int(np.argmax(gains))
    best_feature = features[best]
    # A feature is used at most once along any root-to-leaf path.
    remaining = [f for i, f in enumerate(features) if i != best]
    # Children are keyed by the branch value. Recording which value leads
    # to which subtree is what lets predict() follow the correct edge.
    root = Node(feature=best_feature, label=majority_label(data), child={})
    column = data[:, best_feature]
    for value in np.unique(column):
        root.child[value] = build_tree(data[column == value], remaining)
    return root


def predict(root, x):
    """Predict the class label of sample *x* using a tree from ``build_tree``.

    *x* must be indexable by the same column indices stored in
    ``Node.feature``. Starting at *root*, the walk repeatedly follows the
    child keyed by ``x[node.feature]`` until it reaches a node with no
    children. If *x* holds a value never seen at some node during
    training, that node's stored ``label`` (its majority label) is
    returned instead of None.
    """
    node = root
    while node.child:
        nxt = node.child.get(x[node.feature])
        if nxt is None:
            # Unseen feature value: use this node's majority-label fallback.
            break
        node = nxt
    return node.label


if __name__ == '__main__':
    # Reads data.csv from the working directory. The last column is used as
    # the class label; every other column is treated as a feature.
    data = pd.read_csv('data.csv').values
    features = list(range(data.shape[1] - 1))
    root = build_tree(data, features)
    # Example sample with one value per feature column. It presumably
    # matches a data.csv with two feature columns; adjust to your data.
    # A plain list keeps each value's own type. np.array([1, 'S']) would
    # coerce 1 to the string '1', which never equals an integer read from
    # the CSV, so the lookup would miss.
    x = [1, 'S']
    print(predict(root, x))
```
这里使用pandas读取数据，因此需要将数据保存为CSV文件（data.csv）。其中最后一列为类别标签，其余各列为特征，且所有特征都按离散值处理。Node类表示决策树的节点，包含特征、标签和子节点；entropy函数计算数据集的熵；split_data函数根据特征划分数据集；info_gain函数计算信息增益；majority_label函数返回数据集中出现次数最多的标签；build_tree函数递归构建决策树；predict函数预测新数据的标签。