python写ID3决策树代码
时间: 2023-09-02 12:11:45 浏览: 42
以下是Python实现ID3决策树的代码,其中包括了数据划分、决策树构建(按信息增益选择特征)和决策树预测三个部分。
```python
import numpy as np
import pandas as pd
import math
class Node:
    """A single node of the decision tree.

    A node is either an internal node or a leaf:

    * internal node -- ``feature`` names the column the node splits on,
      ``label`` is ``None`` and ``children`` maps each feature value to the
      sub-tree that handles it;
    * leaf -- ``label`` holds the predicted class and ``children`` stays empty.

    Args:
        feature: Name of the feature this node splits on (internal nodes).
        label: Class label predicted by this node (leaves).
    """

    def __init__(self, feature=None, label=None):
        self.feature = feature
        self.label = label
        # A fresh dict per instance, so nodes never share their children.
        self.children = {}

    def __repr__(self):
        return (
            f"Node(feature={self.feature!r}, label={self.label!r}, "
            f"children={list(self.children)!r})"
        )
class DecisionTree:
    """ID3 decision-tree classifier.

    The tree is grown by repeatedly splitting on the feature with the highest
    information gain and creating one branch per distinct value of that
    feature, so features are treated as discrete values.

    Attributes:
        root: Root node of the fitted tree, ``None`` until ``fit`` is called.
        default_label: Label returned by ``traverse_tree`` when a sample has a
            feature value for which the current node has no branch. ``fit``
            sets it to the most common training label.
    """

    def __init__(self):
        self.root = None
        self.default_label = None

    def fit(self, X, y):
        """Build the tree from the training data.

        Args:
            X: ``pandas.DataFrame`` with one column per feature.
            y: Labels, one per row of ``X`` (``pandas.Series`` or array-like).
                Rows are matched by position, not by index.

        Raises:
            ValueError: If ``X`` and ``y`` differ in length or are empty.
        """
        if len(X) != len(y):
            raise ValueError("X and y must contain the same number of samples")
        if len(y) == 0:
            raise ValueError("cannot fit a decision tree on an empty data set")
        self.default_label = self.get_most_common_label(y)
        self.root = self.build_tree(X, y)

    def predict(self, X):
        """Predict a label for every sample in ``X``.

        Args:
            X: ``pandas.DataFrame`` of samples, or any iterable of row-like
                objects that can be indexed by feature name.

        Returns:
            ``numpy.ndarray`` with one predicted label per sample.
        """
        # Iterating a DataFrame directly yields its column names, not its
        # rows, so convert it to one {column: value} mapping per row first.
        rows = X.to_dict("records") if isinstance(X, pd.DataFrame) else X
        return np.array([self.traverse_tree(row, self.root) for row in rows])

    def build_tree(self, X, y):
        """Recursively build the (sub-)tree for the given samples.

        Args:
            X: ``pandas.DataFrame`` holding the features still available.
            y: Labels of the rows in ``X``, matched by position.

        Returns:
            The root ``Node`` of the constructed (sub-)tree.

        Raises:
            ValueError: If ``y`` is empty.
        """
        labels = np.asarray(y)
        if labels.size == 0:
            raise ValueError("cannot build a tree from an empty label set")

        # Positional access: after filtering, a Series index need not contain 0.
        first_label = labels[0]
        if (labels == first_label).all():
            return Node(label=first_label)
        if X.shape[1] == 0:
            # No feature left to split on: fall back to the majority class.
            return Node(label=self.get_most_common_label(y))

        best_feature = self.get_best_feature(X, y)
        root = Node(feature=best_feature)
        for value in X[best_feature].unique():
            X_sub, y_sub = self.split_data(X, y, best_feature, value)
            if len(y_sub) == 0:
                # NaN never compares equal to itself, so it selects no rows.
                continue
            root.children[value] = self.build_tree(X_sub, y_sub)
        return root

    def traverse_tree(self, x, node):
        """Walk from ``node`` down to a leaf and return its label.

        Args:
            x: One sample, indexable by feature name.
            node: Node at which to start the walk.

        Returns:
            The label of the leaf that is reached, or ``self.default_label``
            when the sample has a feature value with no matching branch.
        """
        while node.label is None:
            child = node.children.get(x[node.feature])
            if child is None:
                return self.default_label
            node = child
        return node.label

    def get_best_feature(self, X, y):
        """Return the column of ``X`` with the highest information gain.

        The first column wins ties. A column is returned even when no split
        improves purity (gain of zero); ``build_tree`` removes the chosen
        column from each subset, so the recursion still terminates.

        Args:
            X: ``pandas.DataFrame`` of candidate features.
            y: Labels of the rows in ``X``, matched by position.

        Returns:
            The name of the best column, or ``None`` if ``X`` has no columns.
        """
        labels = np.asarray(y)
        total = len(labels)
        base_entropy = self.get_entropy(labels)

        best_feature = None
        best_gain = -math.inf
        for feature in X.columns:
            column = X[feature]
            conditional_entropy = 0.0
            for value in column.unique():
                subset = labels[(column == value).to_numpy()]
                weight = len(subset) / total
                conditional_entropy += weight * self.get_entropy(subset)
            gain = base_entropy - conditional_entropy
            if gain > best_gain:
                best_gain = gain
                best_feature = feature
        return best_feature

    def get_entropy(self, y):
        """Return the Shannon entropy (in bits) of the labels ``y``.

        Labels may be of any type ``numpy.unique`` can sort. An empty input
        has entropy ``0.0``.
        """
        labels = np.asarray(y)
        if labels.size == 0:
            return 0.0
        # np.unique only reports labels that occur, so every count is > 0 and
        # log2 is always defined.
        _, counts = np.unique(labels, return_counts=True)
        probabilities = counts / labels.size
        return float(-np.sum(probabilities * np.log2(probabilities)))

    def split_data(self, X, y, feature, value):
        """Select the rows where ``feature`` equals ``value``.

        Args:
            X: ``pandas.DataFrame`` of features.
            y: Labels of the rows in ``X``, matched by position.
            feature: Name of the column to filter on.
            value: Value that ``feature`` must equal.

        Returns:
            ``(X_sub, y_sub)``: the matching rows of ``X`` without the
            ``feature`` column, and the corresponding labels.
        """
        # A plain boolean array selects by position, so X and y do not need
        # to share an index.
        mask = (X[feature] == value).to_numpy()
        X_sub = X.loc[mask].drop(columns=[feature])
        if not isinstance(y, (pd.Series, np.ndarray)):
            y = np.asarray(y)
        y_sub = y[mask]
        return X_sub, y_sub

    def get_most_common_label(self, y):
        """Return the most frequent label in ``y``.

        Ties are resolved in favour of the smallest label.

        Raises:
            ValueError: If ``y`` is empty.
        """
        labels = np.asarray(y)
        if labels.size == 0:
            raise ValueError(
                "cannot determine the most common label of an empty sequence"
            )
        values, counts = np.unique(labels, return_counts=True)
        return values[counts.argmax()]
```
以上是一个基于ID3决策树的Python实现,可以用于分类问题。请注意,传入的X需要是pandas的DataFrame(每一列是一个离散取值的特征,因为ID3会按特征的每个不同取值建立分支),y需要是与X各行一一对应的pandas Series。