Implementing an ID3 decision tree for the iris dataset in Python
Below is example Python code implementing an ID3 decision tree for the iris dataset:
```python
import pandas as pd
import numpy as np
from math import log2

# Load the dataset (expects columns sepal_length, sepal_width,
# petal_length, petal_width, species)
data = pd.read_csv('iris.csv')

# Split into training and test sets (80/20)
train_data = data.sample(frac=0.8, random_state=0)
test_data = data.drop(train_data.index)

# ID3 decision tree classifier
class ID3DecisionTree:
    def __init__(self, max_depth):
        self.max_depth = max_depth

    def fit(self, data, targets, features):
        # Work on a copy of the feature list so tree building never
        # mutates the caller's list
        self.tree = self.build_tree(data, targets, list(features), depth=0)

    def predict(self, data):
        predictions = []
        for _, row in data.iterrows():
            predictions.append(self.traverse_tree(row, self.tree))
        return predictions

    def build_tree(self, data, targets, features, depth):
        # If only one class remains, return a leaf node
        if len(set(targets)) == 1:
            return {'label': targets.iloc[0]}
        # If no features are left or the maximum depth is reached,
        # return a leaf labelled with the majority class
        if not features or depth >= self.max_depth:
            return {'label': targets.value_counts().idxmax()}
        # Pick the feature with the highest information gain
        best_feature, best_gain = None, -1
        for feature in features:
            gain = self.information_gain(data, targets, feature)
            if gain > best_gain:
                best_feature, best_gain = feature, gain
        # If the best split gains nothing, return a majority-class leaf
        if best_gain == 0:
            return {'label': targets.value_counts().idxmax()}
        # Internal node; 'default' handles feature values unseen in training
        tree = {'feature': best_feature,
                'children': {},
                'default': targets.value_counts().idxmax()}
        # Exclude the chosen feature from the subtrees only (a fresh list,
        # so sibling branches are not affected)
        remaining = [f for f in features if f != best_feature]
        for value in data[best_feature].unique():
            sub_data = data[data[best_feature] == value]
            sub_targets = targets.loc[sub_data.index]
            tree['children'][value] = self.build_tree(
                sub_data, sub_targets, remaining, depth + 1)
        return tree

    def information_gain(self, data, targets, feature):
        # Entropy of the full label distribution
        entropy = self.entropy(targets)
        # Conditional entropy of the labels given the feature
        conditional_entropy = 0
        for value in data[feature].unique():
            sub_targets = targets.loc[data[data[feature] == value].index]
            probability = len(sub_targets) / len(targets)
            conditional_entropy += probability * self.entropy(sub_targets)
        # Information gain = entropy - conditional entropy
        return entropy - conditional_entropy

    def entropy(self, targets):
        # Shannon entropy of the label distribution
        entropy = 0
        for count in targets.value_counts():
            probability = count / len(targets)
            entropy += -probability * log2(probability)
        return entropy

    def traverse_tree(self, data, tree):
        # Walk the tree until a leaf is reached and return its label
        if 'label' in tree:
            return tree['label']
        value = data[tree['feature']]
        # Fall back to the node's majority label for unseen feature values
        if value not in tree['children']:
            return tree['default']
        return self.traverse_tree(data, tree['children'][value])

# Feature columns and target column
features = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
target = 'species'

# Train the decision tree
tree = ID3DecisionTree(max_depth=3)
tree.fit(train_data[features], train_data[target], features)

# Predict on the test set and compute accuracy
predictions = tree.predict(test_data[features])
accuracy = np.mean(np.array(predictions) == test_data[target].to_numpy())
print('Accuracy:', accuracy)
```
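The script assumes a local `iris.csv` with exactly the column names used above. If you do not have such a file, one way to generate it (assuming scikit-learn is available; the file name and column names are simply chosen to match what the script expects) is:
```python
import pandas as pd
from sklearn.datasets import load_iris

# Write an iris.csv with the column names the script above expects
iris = load_iris()
df = pd.DataFrame(iris.data, columns=['sepal_length', 'sepal_width',
                                      'petal_length', 'petal_width'])
df['species'] = [iris.target_names[i] for i in iris.target]
df.to_csv('iris.csv', index=False)
```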
Note that this example is tailored to the iris dataset; applying it to another dataset requires adapting the code to that dataset's characteristics. Also note that the iris features are continuous, while ID3 was designed for categorical features: branching on every distinct measurement produces very wide, overfit-prone splits, so in practice the features are usually discretized first, as sketched below.
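As a minimal sketch of that adaptation (the choice of three equal-width bins and the label names are illustrative assumptions, not part of the original code), the continuous columns can be binned with pandas' `pd.cut` before the tree is built:
```python
import pandas as pd

data = pd.read_csv('iris.csv')
features = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']

# Replace each continuous measurement with one of three equal-width bins,
# so the ID3 tree above branches on a small categorical value set
# (bins=3 and the bin labels are arbitrary illustrative choices)
for feature in features:
    data[feature] = pd.cut(data[feature], bins=3,
                           labels=['low', 'mid', 'high'])
```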