Implementing ID3 in Python
Posted: 2023-12-04 16:55:14
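ID3 (Iterative Dichotomiser 3) is a decision tree algorithm that picks each split by information gain. It handles classification problems with categorical features; it is not a regression method. The steps below implement ID3 in Python: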
1. Import the required libraries:
```python
import pandas as pd
import numpy as np
import math
```
2. Define a function that computes the information entropy:
```python
def entropy(target_col):
    # Shannon entropy (in bits) of the class labels in target_col.
    _, counts = np.unique(target_col, return_counts=True)
    probs = counts / counts.sum()
    return -np.sum(probs * np.log2(probs))
```
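As a quick sanity check of the entropy definition H = -Σ p·log2(p), the sketch below evaluates it on three small label lists. The helper name `label_entropy` and the sample labels are illustrative assumptions, not part of the original code; the helper repeats the same formula so the block runs on its own.

```python
import numpy as np

def label_entropy(labels):
    """Shannon entropy of a label sequence, in bits (hypothetical helper)."""
    _, counts = np.unique(labels, return_counts=True)
    probs = counts / counts.sum()
    return float(-np.sum(probs * np.log2(probs)))

balanced = label_entropy(["yes", "no", "yes", "no"])  # two equally likely classes
pure = label_entropy(["yes", "yes", "yes"])           # a single class
skewed = label_entropy(["yes"] * 9 + ["no"] * 5)      # 9 vs 5 split

print(f"balanced: {balanced:.3f} bits")
print(f"pure:     {abs(pure):.3f} bits")  # abs() hides the sign of a floating-point -0.0
print(f"skewed:   {skewed:.3f} bits")
```

A balanced two-class list gives 1 bit, a pure list gives 0 bits, and the 9-to-5 split gives roughly 0.940 bits, so lower entropy means a purer set of labels.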
3. Define a function that computes the information gain:
```python
def info_gain(data, split_attribute_name, target_name="class"):
    # Entropy of the target before splitting.
    total_entropy = entropy(data[target_name])
    vals, counts = np.unique(data[split_attribute_name], return_counts=True)
    # Weighted average entropy of the subsets produced by the split.
    # Boolean indexing avoids the NaN fill and dtype changes of where().dropna().
    weighted_entropy = sum(
        (counts[i] / counts.sum())
        * entropy(data[data[split_attribute_name] == vals[i]][target_name])
        for i in range(len(vals))
    )
    return total_entropy - weighted_entropy
```
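To make the information gain concrete, here is a small worked example on a made-up four-row table. The names `entropy_bits`, `gain`, and `toy`, and the data itself, are assumptions for illustration; the gain is computed with `groupby` rather than the loop used above, which should give the same number.

```python
import numpy as np
import pandas as pd

def entropy_bits(labels):
    """Shannon entropy of a label column, in bits (hypothetical helper)."""
    _, counts = np.unique(labels, return_counts=True)
    probs = counts / counts.sum()
    return float(-np.sum(probs * np.log2(probs)))

def gain(df, attribute, target="class"):
    """Entropy of the target minus the weighted entropy after splitting on attribute."""
    weighted = sum(
        len(group) / len(df) * entropy_bits(group[target])
        for _, group in df.groupby(attribute)
    )
    return entropy_bits(df[target]) - weighted

# Made-up data: A determines the class completely, B carries no information.
toy = pd.DataFrame({
    "A": [0, 0, 1, 1],
    "B": [0, 1, 0, 1],
    "class": ["no", "no", "yes", "yes"],
})

gain_a = gain(toy, "A")
gain_b = gain(toy, "B")
print(f"gain(A) = {gain_a:.3f}, gain(B) = {gain_b:.3f}")
```

Splitting on `A` removes all uncertainty (gain of 1 bit), while splitting on `B` leaves each subset as mixed as before (gain of 0), so ID3 would pick `A` here.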
4. Define a function that returns the attribute with the largest information gain:
```python
def get_best_attribute(data, target_name="class"):
    # Candidate features are all columns except the target.
    features = [col for col in data.columns if col != target_name]
    info_gains = [info_gain(data, feature, target_name) for feature in features]
    # Return the feature with the largest information gain.
    return features[np.argmax(info_gains)]
```
5. Define the ID3 function:
```python
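def id3(data, original_data, features, target_attribute_name="class", parent_node_class=None):
    # Empty subset: fall back to the majority class of the parent data.
    # This check comes first because indexing np.unique(...)[0] fails on empty data.
    if len(data) == 0:
        classes, counts = np.unique(original_data[target_attribute_name], return_counts=True)
        return classes[np.argmax(counts)]
    # Pure subset: every sample has the same class, so return it as a leaf.
    if len(np.unique(data[target_attribute_name])) == 1:
        return np.unique(data[target_attribute_name])[0]
    # No features left to split on: return the parent's majority class.
    if len(features) == 0:
        return parent_node_class
    # Majority class of the current node, passed down as the fallback.
    classes, counts = np.unique(data[target_attribute_name], return_counts=True)
    parent_node_class = classes[np.argmax(counts)]
    # Restrict the candidates to the remaining features so that an attribute
    # already used higher up the tree is never selected again.
    best_feature = get_best_attribute(data[list(features) + [target_attribute_name]])
    tree = {best_feature: {}}
    features = [f for f in features if f != best_feature]
    # Grow one branch per observed value of the chosen feature.
    for value in np.unique(data[best_feature]):
        sub_data = data[data[best_feature] == value]
        tree[best_feature][value] = id3(sub_data, data, features, target_attribute_name, parent_node_class)
    return tree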
```
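The tree built in this step is a nested dictionary of the form `{attribute: {value: subtree_or_label}}`, which can be hard to read when printed directly. The sketch below is one possible way to render it as indented text; `format_tree` and the hand-written `example_tree` are hypothetical and only mimic the shape of the `id3` output.

```python
def format_tree(tree, indent=0):
    """Return the lines of an indented text rendering of a nested-dict tree."""
    pad = " " * indent
    # A non-dict node is a leaf holding the predicted class label.
    if not isinstance(tree, dict):
        return [f"{pad}-> {tree}"]
    lines = []
    for attribute, branches in tree.items():
        for value, subtree in branches.items():
            lines.append(f"{pad}{attribute} = {value}")
            lines.extend(format_tree(subtree, indent + 4))
    return lines

# Hand-written tree with the same structure as the dictionaries built above.
example_tree = {"A": {0: "no", 1: {"B": {0: "no", 1: "yes"}}}}
rendered = format_tree(example_tree)
print("\n".join(rendered))
```

Each `attribute = value` line is a test on the path from the root, and each `->` line is the class label of the leaf reached by the tests above it.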
6. Define the prediction function:
```python
def predict(query, tree, default='default'):
    # Walk the tree, following the branch that matches the query at each node.
    for key in query:
        if key in tree:
            try:
                result = tree[key][query[key]]
            except KeyError:
                # Attribute value not seen during training.
                return default
            if isinstance(result, dict):
                return predict(query, result, default)
            return result
    # No attribute in the query matches this node.
    return default
```
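Example code that builds a decision tree with ID3 and predicts a sample. It assumes that `example_data.csv` contains the categorical feature columns `A`, `B` and `C`, with the label in a final column named `class`: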
```python
data = pd.read_csv('example_data.csv')
features = list(data.columns[:-1])
tree = id3(data, data, features)
query = {'A': 1, 'B': 1, 'C': 0}
print(predict(query, tree))
```
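Because `example_data.csv` is not provided, the following self-contained sketch runs the same workflow on a small in-memory table. Everything in it is an assumption made for illustration: the toy data (the class is "yes" only when `A` and `B` are both 1, and `C` is noise) and the compact stand-in functions `build_tree` and `classify`, which restate the steps above in condensed form rather than reproduce them exactly.

```python
import numpy as np
import pandas as pd

def entropy(labels):
    """Shannon entropy of a label column, in bits."""
    _, counts = np.unique(labels, return_counts=True)
    probs = counts / counts.sum()
    return float(-np.sum(probs * np.log2(probs)))

def info_gain(data, attribute, target="class"):
    """Reduction in target entropy obtained by splitting on attribute."""
    weighted = sum(
        len(subset) / len(data) * entropy(subset[target])
        for _, subset in data.groupby(attribute)
    )
    return entropy(data[target]) - weighted

def build_tree(data, features, target="class"):
    """Compact ID3: returns a class label (leaf) or {attribute: {value: subtree}}."""
    labels = data[target]
    if labels.nunique() == 1:          # pure node
        return labels.iloc[0]
    if not features:                   # nothing left to split on: majority class
        return labels.mode().iloc[0]
    best = max(features, key=lambda f: info_gain(data, f, target))
    remaining = [f for f in features if f != best]
    return {
        best: {
            value: build_tree(data[data[best] == value], remaining, target)
            for value in data[best].unique()
        }
    }

def classify(query, tree, default=None):
    """Follow the branches matching the query until a leaf is reached."""
    while isinstance(tree, dict):
        attribute = next(iter(tree))
        branches = tree[attribute]
        if query.get(attribute) not in branches:
            return default             # missing attribute or unseen value
        tree = branches[query[attribute]]
    return tree

# Made-up training data: class is "yes" only when A == 1 and B == 1.
toy = pd.DataFrame({
    "A": [0, 0, 1, 1, 0, 1, 1, 0],
    "B": [0, 1, 0, 1, 0, 1, 0, 1],
    "C": [0, 0, 0, 0, 1, 1, 1, 1],
    "class": ["no", "no", "no", "yes", "no", "yes", "no", "no"],
})

tree = build_tree(toy, ["A", "B", "C"])
prediction = classify({"A": 1, "B": 1, "C": 0}, tree)
print(tree)
print(prediction)
```

On this table the tree splits only on `A` and `B`, since `C` has zero information gain at every node, and the query with `A = 1`, `B = 1` is classified as "yes".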
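These are the steps for implementing ID3 in Python. The resulting functions build a decision tree for classification problems and use it to predict the class of new samples.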