Help me write code for the ID3 algorithm
Sure, here is a simple Python implementation of the ID3 decision tree algorithm; I hope it helps.
```python
import math
import numpy as np

class DecisionTree:
    def __init__(self):
        self.tree = {}

    def fit(self, X, Y, feature_names):
        X = np.asarray(X, dtype=object)
        Y = np.asarray(Y)
        self.feature_names = list(feature_names)
        # Fallback class for feature values never seen during training
        self.default_class = self._most_common_class(Y)
        self.tree = self._build_tree(X, Y, self.feature_names)

    def predict(self, X):
        X = np.asarray(X, dtype=object)
        return np.array([self._predict(row, self.tree) for row in X])

    def _build_tree(self, X, Y, feature_names):
        n_samples, n_features = X.shape
        # Stop if all samples share one class
        if len(set(Y)) == 1:
            return Y[0]
        # Stop if there are no more features to split on
        if n_features == 0:
            return self._most_common_class(Y)
        # Pick the feature with the highest information gain
        information_gains = [self._information_gain(X, Y, i)
                             for i in range(n_features)]
        best_feature_index = int(np.argmax(information_gains))
        best_feature = feature_names[best_feature_index]
        # Save the column's values BEFORE deleting it, so the masks
        # are computed against the correct column
        feature_values = X[:, best_feature_index]
        X_rest = np.delete(X, best_feature_index, axis=1)
        remaining_names = (feature_names[:best_feature_index]
                           + feature_names[best_feature_index + 1:])
        # Recursively build one subtree per observed feature value
        tree = {best_feature: {}}
        for value in set(feature_values):
            mask = feature_values == value
            tree[best_feature][value] = self._build_tree(
                X_rest[mask], Y[mask], remaining_names)
        return tree

    def _predict(self, row, tree):
        # A leaf stores a class label directly
        if not isinstance(tree, dict):
            return tree
        # Each internal node is a dict with exactly one feature key
        feature = next(iter(tree))
        value = row[self.feature_names.index(feature)]
        subtree = tree[feature]
        if value in subtree:
            return self._predict(row, subtree[value])
        # Unseen feature value: fall back to the training majority class
        return self.default_class

    def _most_common_class(self, Y):
        classes, counts = np.unique(Y, return_counts=True)
        return classes[np.argmax(counts)]

    def _entropy(self, Y):
        n_samples = len(Y)
        entropy = 0.0
        for c in set(Y):
            p = np.sum(Y == c) / n_samples
            entropy -= p * math.log2(p)
        return entropy

    def _information_gain(self, X, Y, feature_index):
        n_samples = X.shape[0]
        feature_values = X[:, feature_index]
        # Weighted entropy of the partition induced by this feature
        feature_entropy = 0.0
        for feature_value in set(feature_values):
            value_Y = Y[feature_values == feature_value]
            feature_entropy += len(value_Y) / n_samples * self._entropy(value_Y)
        # Information gain = entropy before the split - entropy after
        return self._entropy(Y) - feature_entropy
```
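For reference, the `_entropy` and `_information_gain` helpers compute the standard quantities ID3 splits on:

$$
H(Y) = -\sum_{c} p(c)\,\log_2 p(c)
$$

$$
IG(Y, A) = H(Y) - \sum_{v \in \text{values}(A)} \frac{|Y_v|}{|Y|}\, H(Y_v)
$$

where $Y_v$ is the subset of labels whose samples take value $v$ on feature $A$; the tree greedily splits on the feature with the largest gain.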
Note: this is only a simple implementation (categorical features only, no pruning, no handling of missing values), so it may not cope with more complex datasets; practical use will likely require further optimization and refinement.
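Despite those caveats, the class can be exercised end to end. Below is a minimal usage sketch on a small made-up weather-style dataset; the feature names and values are purely illustrative, not from any real data:

```python
# Made-up toy dataset for illustration only
X = [
    ["sunny",    "hot",  "high"],
    ["sunny",    "hot",  "normal"],
    ["overcast", "hot",  "high"],
    ["rain",     "mild", "high"],
    ["rain",     "cool", "normal"],
    ["overcast", "cool", "normal"],
]
Y = ["no", "yes", "yes", "yes", "yes", "yes"]
feature_names = ["outlook", "temperature", "humidity"]

clf = DecisionTree()
clf.fit(X, Y, feature_names)
print(clf.tree)  # nested dict: {feature: {value: subtree_or_label, ...}}
print(clf.predict([["sunny", "cool", "high"],
                   ["overcast", "mild", "high"]]))
```

`fit` expects categorical features and a list of feature names matching the columns of `X`; unseen feature values at prediction time fall back to the training set's majority class.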