id3算法python实现鸢尾花
时间: 2023-09-30 13:05:41 浏览: 185
好的,id3算法是一种决策树算法,可以用于分类问题。以下是使用Python实现id3算法来对鸢尾花数据集进行分类的示例代码:
首先,我们需要导入必要的库和鸢尾花数据集:
```python
import pandas as pd
from sklearn.datasets import load_iris
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['target'] = iris.target
```
接下来,我们需要定义一个函数来计算给定数据集的熵:
```python
import math
def entropy(target_col):
elements, counts = np.unique(target_col, return_counts=True)
entropy = np.sum([(-counts[i]/np.sum(counts)) * math.log2(counts[i]/np.sum(counts)) for i in range(len(elements))])
return entropy
```
然后,我们需要编写一个函数来计算每个特征的信息增益:
```python
def info_gain(data, split_attribute_name, target_name="target"):
total_entropy = entropy(data[target_name])
vals, counts= np.unique(data[split_attribute_name], return_counts=True)
weighted_entropy = np.sum([(counts[i]/np.sum(counts)) * entropy(data.where(data[split_attribute_name]==vals[i]).dropna()[target_name]) for i in range(len(vals))])
info_gain = total_entropy - weighted_entropy
return info_gain
```
现在,我们可以开始构建id3决策树。首先,我们需要编写一个函数来选择最佳划分特征:
```python
def id3(data, original_data, features, target_attribute_name="target", parent_node_class = None):
# 如果所有数据属于同一类别,则返回该类别
if len(np.unique(data[target_attribute_name])) <= 1:
return np.unique(data[target_attribute_name])[0]
# 如果数据集为空,则返回父节点中出现最频繁的类别
elif len(data)==0:
return np.unique(original_data[target_attribute_name])[np.argmax(np.unique(original_data[target_attribute_name],return_counts=True)[1])]
# 如果没有特征可用于分割数据集,则返回父节点中出现最频繁的类别
elif len(features) ==0:
return parent_node_class
# 如果数据集非空且存在可以使用的特征,则选择最佳划分特征
else:
parent_node_class = np.unique(data[target_attribute_name])[np.argmax(np.unique(data[target_attribute_name],return_counts=True)[1])]
item_values = [info_gain(data,feature,target_attribute_name) for feature in features]
best_feature_index = np.argmax(item_values)
best_feature = features[best_feature_index]
tree = {best_feature:{}}
# 从特征列表中删除已选择的最佳特征
features = [i for i in features if i != best_feature]
# 创建子树
for value in np.unique(data[best_feature]):
sub_data = data.where(data[best_feature] == value).dropna()
subtree = id3(sub_data,original_data,features,target_attribute_name,parent_node_class)
tree[best_feature][value] = subtree
return(tree)
```
最后,我们可以使用以下代码来训练决策树并对测试集进行预测:
```python
from sklearn.model_selection import train_test_split
# 将数据集拆分为训练集和测试集
train, test = train_test_split(df, test_size=0.2)
# 训练决策树
features = df.columns[:-1].tolist()
tree = id3(train,train,features)
# 对测试集进行预测
def predict(query,tree,default = 1):
for key in list(query.keys()):
if key in list(tree.keys()):
try:
result = tree[key][query[key]]
except:
return default
result = tree[key][query[key]]
if isinstance(result,dict):
return predict(query,result)
else:
return result
# 评估模型的准确性
test['predicted'] = test.apply(predict, axis=1, args=(tree, -1))
accuracy = sum(test["predicted"] == test["target"]) / len(test)
print(f"Accuracy: {accuracy}")
```
以上就是使用Python实现id3算法对鸢尾花数据集进行分类的完整代码。
阅读全文