随机生成一组数据集,用id3算法进行训练(python代码)
时间: 2023-12-21 17:07:42 浏览: 60
好的,以下是一个简单的随机生成数据集并使用ID3算法进行训练的Python代码示例:
```python
import numpy as np
from collections import Counter
import math
class DecisionTree:
    """ID3 decision tree classifier for categorical (discrete) features."""

    def __init__(self):
        # Nested dict: {feature_index: {feature_value: subtree-or-label}};
        # a non-dict node is a leaf holding a class label.
        self.tree = {}
        # Majority class of the training labels; used as a fallback when a
        # sample reaches a branch value never seen during training.
        self.majority_ = None

    def entropy(self, y):
        """Return the Shannon entropy (base 2) of the label sequence y."""
        counter = Counter(y)
        entropy = 0.0
        for label in counter:
            prob = counter[label] / len(y)
            entropy -= prob * math.log(prob, 2)
        return entropy

    def split_data(self, X, y, feature_index):
        """Group rows by their value in column feature_index.

        Returns {feature_value: (row_indices, labels)}.
        """
        splits = {}
        for idx, feature in enumerate(X[:, feature_index]):
            if feature not in splits:
                splits[feature] = ([], [])
            splits[feature][0].append(idx)
            splits[feature][1].append(y[idx])
        return splits

    def information_gain(self, X, y, feature_index):
        """Return the information gain of splitting (X, y) on feature_index."""
        splits = self.split_data(X, y, feature_index)
        entropy_total = self.entropy(y)
        entropy_feature = 0.0
        for feature, (indices, sub_y) in splits.items():
            prob = len(indices) / len(y)
            entropy_feature += prob * self.entropy(sub_y)
        return entropy_total - entropy_feature

    def get_best_feature(self, X, y):
        """Return the column index with the largest positive information gain,
        or -1 when no feature yields any gain."""
        best_feature_index = -1
        best_information_gain = 0.0
        for feature_index in range(X.shape[1]):
            gain = self.information_gain(X, y, feature_index)
            if gain > best_information_gain:
                best_information_gain = gain
                best_feature_index = feature_index
        return best_feature_index

    def fit(self, X, y):
        """Build the tree from X (2-D array of categorical features) and labels y."""
        # Remember the overall majority class for unseen-value fallback.
        self.majority_ = Counter(y).most_common(1)[0][0]
        self.tree = self.build_tree(X, y)

    def build_tree(self, X, y):
        """Recursively construct the ID3 tree for the samples (X, y)."""
        # Pure node: all labels identical -> leaf.
        if len(set(y)) == 1:
            return y[0]
        # No features left: majority vote.
        if X.shape[1] == 0:
            return Counter(y).most_common(1)[0][0]
        best_feature_index = self.get_best_feature(X, y)
        # BUG FIX: the original recursed even when get_best_feature returned -1
        # (no feature with positive gain), which silently split on the *last*
        # column and recursed forever on constant-feature data with mixed
        # labels. Stop and return the majority label instead.
        if best_feature_index == -1:
            return Counter(y).most_common(1)[0][0]
        tree = {best_feature_index: {}}
        splits = self.split_data(X, y, best_feature_index)
        for feature, (indices, sub_y) in splits.items():
            sub_X = X[indices]
            tree[best_feature_index][feature] = self.build_tree(sub_X, sub_y)
        return tree

    def predict(self, X):
        """Return a list with one predicted label per row of X."""
        return [self.predict_one(x) for x in X]

    def predict_one(self, x):
        """Walk the tree for a single sample x and return its label.

        BUG FIX: the original raised KeyError for a feature value that never
        occurred during training; fall back to the training majority class.
        """
        tree = self.tree
        while isinstance(tree, dict):
            feature_index = list(tree.keys())[0]
            branches = tree[feature_index]
            value = x[feature_index]
            if value not in branches:
                return self.majority_
            tree = branches[value]
        return tree
# Generate a random binary dataset: 100 samples, 5 binary features.
X = np.random.randint(2, size=(100, 5))
y = np.random.randint(2, size=(100,))
# Train an ID3 decision tree on the data.
model = DecisionTree()
model.fit(X, y)
# Evaluate on the training set and report accuracy.
predictions = model.predict(X)
print("Accuracy:", np.mean(predictions == y))
```
代码中使用numpy库生成随机的二分类数据集,然后定义了一个DecisionTree类用于实现ID3算法。其中,entropy函数用于计算信息熵,split_data函数用于分割数据集,information_gain函数用于计算信息增益,get_best_feature函数用于选择最佳特征,fit函数用于训练模型,build_tree函数用于递归构建决策树,predict函数用于预测结果,predict_one函数用于迭代遍历决策树以预测单个样本的结果。最后输出预测准确率。
阅读全文