使用ID3算法对鸢尾花数据集进行分类的Python代码
时间: 2023-09-26 16:11:25 浏览: 169
以下是使用ID3算法对鸢尾花数据集进行分类的Python代码:
```
import pandas as pd
import math
# Load the dataset from "iris.csv" (relative path).
# NOTE(review): the rest of this script reads the class label from the last
# column and treats every earlier column as a feature -- confirm the CSV has
# no extra ID/index column.
data = pd.read_csv("iris.csv")
def entropy(data):
    """Return the Shannon entropy, in bits, of the class labels in ``data``.

    The class label is read from the last column of ``data``.

    Args:
        data: pandas DataFrame whose last column holds the class labels.

    Returns:
        The entropy in bits. The result is 0 for an empty frame and for a
        frame in which every row carries the same label.
    """
    labels = data.iloc[:, -1]
    counts = labels.value_counts()
    # A categorical label column reports unused categories with a count of 0.
    # log2(0) is undefined and such classes contribute nothing to the entropy,
    # so they are dropped before taking logarithms.
    probs = counts[counts > 0] / len(labels)
    return -sum(p * math.log2(p) for p in probs)
def info_gain(data, feature):
    """Return the information gain obtained by splitting ``data`` on ``feature``.

    The gain is the entropy of ``data`` minus the size-weighted entropy of the
    subsets produced by grouping rows on each distinct value of ``feature``.

    Args:
        data: pandas DataFrame whose last column holds the class labels.
        feature: name of the column to split on.

    Returns:
        The entropy reduction, in bits.
    """
    column = data[feature]
    total_rows = len(data)
    weighted_child_entropy = 0
    # One branch per distinct value, visited in order of first appearance.
    for level in column.unique():
        part = data[column == level]
        weighted_child_entropy += len(part) / total_rows * entropy(part)
    return entropy(data) - weighted_child_entropy
def build_tree(data, features):
    """Recursively build an ID3 decision tree represented as nested dicts.

    Args:
        data: pandas DataFrame whose last column holds the class labels.
        features: iterable of column names that may still be used for
            splitting. Any iterable is accepted, not only a set.

    Returns:
        A class label when the node is a leaf, otherwise a dict of the form
        ``{feature: {value: subtree_or_label, ...}}``.
    """
    labels = data.iloc[:, -1]

    # Every sample belongs to the same class: this node is a leaf.
    if len(labels.unique()) == 1:
        return labels.iloc[0]

    # Walk the candidates in column order rather than set order. Set iteration
    # order depends on string hash randomization, so ties in information gain
    # would otherwise be resolved differently from one run to the next.
    # The label column (last) is never offered as a split candidate.
    wanted = set(features)
    candidates = [column for column in data.columns[:-1] if column in wanted]

    # No feature left to split on: fall back to the most frequent class.
    if not candidates:
        return labels.value_counts().idxmax()

    # max() keeps the first maximum, i.e. the left-most column among ties.
    best_feature = max(candidates, key=lambda name: info_gain(data, name))
    remaining = [name for name in candidates if name != best_feature]

    branches = {}
    for value in data[best_feature].unique():
        subset = data[data[best_feature] == value]
        if len(subset) == 0:
            # Reached when equality matches no row (e.g. a NaN feature value,
            # which never compares equal to itself): use the majority class.
            branches[value] = labels.value_counts().idxmax()
        else:
            branches[value] = build_tree(
                subset.drop(best_feature, axis=1), remaining
            )
    return {best_feature: branches}
# Split into training and test sets: from each 50-row slice starting at rows
# 0, 50 and 100, the first 40 rows go to training and the last 10 to testing.
# NOTE(review): this presumes 150 rows ordered as three 50-row class blocks --
# confirm against the CSV.
train_data = pd.concat(
    [data.iloc[start:start + 40] for start in (0, 50, 100)], axis=0
)
test_data = pd.concat(
    [data.iloc[start + 40:start + 50] for start in (0, 50, 100)], axis=0
)
train_data = train_data.reset_index(drop=True)
test_data = test_data.reset_index(drop=True)

# Build the decision tree from every column except the last (the label).
features = set(train_data.columns[:-1])
tree = build_tree(train_data, features)
def _majority_leaf(branches):
    """Return the most frequent leaf label found beneath ``branches``.

    ``branches`` maps feature values to subtrees or leaf labels. Each leaf
    counts once (the vote is not weighted by training-sample counts). Ties go
    to the label encountered first. Returns ``None`` when there are no leaves.
    """
    from collections import Counter

    votes = Counter()
    pending = list(branches.values())
    position = 0
    # Breadth-first walk so the visit order (and thus tie-breaking) is stable.
    while position < len(pending):
        node = pending[position]
        position += 1
        if isinstance(node, dict):
            for child_branches in node.values():
                pending.extend(child_branches.values())
        else:
            votes[node] += 1
    if not votes:
        return None
    return votes.most_common(1)[0][0]


def predict(tree, instance):
    """Classify one sample by walking a nested-dict decision tree.

    Args:
        tree: a leaf (a class label of any non-dict type) or an internal node
            of the form ``{feature: {value: subtree_or_label, ...}}``.
        instance: mapping-like sample (dict or pandas Series) indexed by
            feature name.

    Returns:
        The predicted class label. When the sample's value for the tested
        feature has no matching branch, the most frequent leaf label below
        that node is returned instead.

    Raises:
        KeyError: if the value has no matching branch and the node has no
            leaves to vote with, or if ``instance`` lacks the tested feature.
    """
    # Anything that is not a dict is a leaf; labels need not be strings.
    if not isinstance(tree, dict):
        return tree

    feature = next(iter(tree))
    branches = tree[feature]
    value = instance[feature]
    if value in branches:
        return predict(branches[value], instance)

    # The tree has one branch per value seen during training, so a value first
    # met at prediction time has no branch; fall back to a vote of the leaves.
    fallback = _majority_leaf(branches)
    if fallback is None:
        raise KeyError(value)
    return fallback
# Predict a label for every test row, addressed by position.
test_rows = (test_data.iloc[pos] for pos in range(len(test_data)))
predictions = [predict(tree, row) for row in test_rows]
actuals = test_data.iloc[:, -1].tolist()

# Accuracy: fraction of test rows whose prediction equals the true label.
correct = sum(
    1 if predicted == actual else 0
    for predicted, actual in zip(predictions, actuals)
)
accuracy = correct / len(predictions)
print("Accuracy:", accuracy)
```
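需要注意的是,该代码使用pandas库来读取和处理数据集,因此需要先安装该库(math属于Python标准库,无需安装)。此外,ID3算法原本针对离散特征,而鸢尾花的四个特征都是连续数值;代码把每个不同的取值都当作一个独立分支,测试样本中若出现训练集里没有的取值,就无法在树中匹配到对应分支,实际使用前最好先对特征进行离散化(分箱)。决策树是递归构建的,每递归一层就去掉一个特征,因此递归深度最多等于特征个数;只有在特征数量非常多时,才可能触及Python的递归深度限制(栈溢出)。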
阅读全文