决策树ID3预测实例代码
时间: 2024-03-31 19:30:23 浏览: 16
决策树ID3是一种常用的分类算法,它通过计算信息增益来选择最优的特征进行分类。下面是一个简单的决策树ID3预测的实例代码:
```python
import numpy as np
def calc_entropy(data):
    """Return the Shannon entropy of the class-label column of *data*.

    The last column of *data* is taken to be the class label; entropy is
    -sum(p * log2(p)) over the empirical label probabilities.
    """
    _, counts = np.unique(data[:, -1], return_counts=True)
    probabilities = counts / counts.sum()
    return -np.sum(probabilities * np.log2(probabilities))
def split_data(data, feature_index, feature_value):
    """Return the rows of *data* whose column *feature_index* equals *feature_value*.

    All columns are kept; only rows are filtered.
    """
    selected_rows = data[:, feature_index] == feature_value
    return data[selected_rows]
def choose_best_feature(data):
    """Return the index of the feature with the highest information gain.

    Returns -1 when no feature yields a strictly positive gain over the
    base entropy of the full dataset.
    """
    base_entropy = calc_entropy(data)
    best_gain, best_index = 0.0, -1
    n_rows = len(data)
    # The last column is the class label, so it is excluded from candidates.
    for feature_index in range(data.shape[1] - 1):
        # Conditional entropy after partitioning on this feature.
        conditional_entropy = 0.0
        for value in np.unique(data[:, feature_index]):
            subset = split_data(data, feature_index, value)
            conditional_entropy += len(subset) / n_rows * calc_entropy(subset)
        gain = base_entropy - conditional_entropy
        if gain > best_gain:
            best_gain, best_index = gain, feature_index
    return best_index
def majority_vote(labels):
    """Return the most frequent label in *labels*.

    Ties are broken by np.argmax, i.e. the first maximal count in
    np.unique's sorted order wins.
    """
    values, counts = np.unique(labels, return_counts=True)
    return values[np.argmax(counts)]
def create_decision_tree(data, features):
    """Recursively build an ID3 decision tree.

    Parameters
    ----------
    data : np.ndarray
        Sample rows; the last column is the class label. Note that a mixed
        int/str literal array is coerced to an all-string array by numpy.
    features : sequence of str
        Names of the feature columns, in the same order as the columns of
        *data* (label column excluded).

    Returns
    -------
    Either a leaf label, or a nested dict of the form
    {feature_name: {feature_value: subtree_or_label}}.
    """
    labels = data[:, -1]
    # All samples share one class: return that label as a leaf.
    if len(np.unique(labels)) == 1:
        return labels[0]
    # Only the label column remains (all features used up): majority vote.
    if len(data[0]) == 1:
        return majority_vote(labels)
    best_feature_index = choose_best_feature(data)
    # Fix: choose_best_feature returns -1 when no feature gives positive
    # information gain; indexing features[-1] would silently pick the wrong
    # feature, so fall back to a majority-vote leaf instead.
    if best_feature_index == -1:
        return majority_vote(labels)
    best_feature = features[best_feature_index]
    decision_tree = {best_feature: {}}
    for value in np.unique(data[:, best_feature_index]):
        sub_features = np.delete(features, best_feature_index)
        sub_data = split_data(data, best_feature_index, value)
        # Fix: the original removed the used feature from `features` but NOT
        # its column from `data`, so feature indices returned by
        # choose_best_feature in the recursion no longer matched the shrunken
        # feature list (IndexError / wrong feature name). Delete the column
        # so data and features stay aligned.
        sub_data = np.delete(sub_data, best_feature_index, axis=1)
        decision_tree[best_feature][value] = create_decision_tree(sub_data, sub_features)
    return decision_tree
def predict(decision_tree, features, test_data):
    """Classify one sample by walking the decision tree.

    Parameters
    ----------
    decision_tree : dict
        Tree from create_decision_tree: {feature_name: {value: subtree}}.
    features : sequence of str
        Feature names, in the same column order as *test_data*.
    test_data : sequence
        One sample's feature values (no label column).

    Returns the predicted label. Raises KeyError if the sample has a
    feature value that never occurred in the training data.
    """
    root_feature = next(iter(decision_tree))
    branches = decision_tree[root_feature]
    # Fix: list(...) so this also works when `features` is an ndarray
    # (np.delete in create_decision_tree returns arrays, which lack .index).
    root_value = test_data[list(features).index(root_feature)]
    # Fix: np.array coerces mixed int/str training data to strings, so the
    # branch keys are strings; retry with str() so numeric test inputs
    # (e.g. [1, 0]) still match the '1'/'0' keys instead of raising KeyError.
    if root_value not in branches:
        root_value = str(root_value)
    sub_tree = branches[root_value]
    if isinstance(sub_tree, dict):
        return predict(sub_tree, features, test_data)
    return sub_tree
# ----- Demo: train on a toy dataset and classify one sample -----
# NOTE: np.array coerces this mixed int/str literal to an all-string array,
# so every stored feature value and label is a string such as '1' or 'yes'.
data = np.array([
    [1, 1, 'yes'],
    [1, 1, 'yes'],
    [1, 0, 'no'],
    [0, 1, 'no'],
    [0, 1, 'no'],
])
# Feature names, in the same column order as the data matrix.
features = ['feature1', 'feature2']
# Build the ID3 tree from the training set.
decision_tree = create_decision_tree(data, features)
# Fix: the tree's branch keys are strings (see coercion note above), so the
# test sample must use string values too; [1, 0] raised KeyError on lookup.
test_data = ['1', '0']
prediction = predict(decision_tree, features, test_data)
print("预测结果:", prediction)
```
这段代码实现了决策树ID3算法的预测过程。首先,通过计算信息增益选择最优的特征进行分类,然后根据特征和特征值划分数据集,递归地创建决策树。最后,使用决策树进行预测。