ID3决策树代码见CreatTree.py。请结合算法原理,通过认真阅读、单步调试等方法,掌握ID3决策树的代码实现,并根据自己的理解,对代码进行详细注释。
时间: 2024-02-01 14:13:01 浏览: 199
好的,下面我来详细讲解一下ID3决策树的代码实现,并添加注释来帮助理解。
```python
from math import log
import operator
# 计算给定数据集的香农熵
def calcShannonEnt(dataSet):
numEntries = len(dataSet) # 数据集中实例的总数
labelCounts = {} # 创建一个空字典,用于存储每个类别出现的次数
# 对每组特征向量进行统计
for featVec in dataSet:
currentLabel = featVec[-1] # 每个实例的类别
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0 # 初始化香农熵
# 计算香农熵
for key in labelCounts:
prob = float(labelCounts[key]) / numEntries
shannonEnt -= prob * log(prob, 2)
return shannonEnt
# 创建测试数据集
def createDataSet():
dataSet = [[1, 1, 'yes'],
[1, 1, 'yes'],
[1, 0, 'no'],
[0, 1, 'no'],
[0, 1, 'no']]
labels = ['no surfacing', 'flippers']
return dataSet, labels
# 按照给定特征划分数据集
def splitDataSet(dataSet, axis, value):
retDataSet = [] # 创建新的list对象
for featVec in dataSet:
if featVec[axis] == value:
reducedFeatVec = featVec[:axis] # 去掉axis特征
reducedFeatVec.extend(featVec[axis+1:])
retDataSet.append(reducedFeatVec)
return retDataSet
# 选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) - 1 # 特征数
baseEntropy = calcShannonEnt(dataSet) # 计算数据集的香农熵
bestInfoGain = 0.0 # 初始化信息增益
bestFeature = -1 # 初始化最优特征的索引值
for i in range(numFeatures): # 对每个特征循环
featList = [example[i] for example in dataSet] # 取得所有样本该特征的取值
uniqueVals = set(featList) # 创建set集合,元素不重复
newEntropy = 0.0 # 初始化经验条件熵
for value in uniqueVals: # 对每个特征划分一次数据集
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet) / float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy # 信息增益
if (infoGain > bestInfoGain): # 比较信息增益
bestInfoGain = infoGain # 更新信息增益
bestFeature = i # 记录信息增益最大的特征的索引值
return bestFeature
# 多数表决规则,决定叶子节点的分类
def majorityCnt(classList):
classCount = {}
for vote in classList:
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
# 创建树
def createTree(dataSet, labels):
classList = [example[-1] for example in dataSet] # 取得所有类别
if classList.count(classList[0]) == len(classList): # 类别完全相同,停止划分
return classList[0]
if len(dataSet[0]) == 1: # 遍历完所有特征时返回出现次数最多的类别
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet) # 最优特征的索引值
bestFeatLabel = labels[bestFeat] # 最优特征的标签
myTree = {bestFeatLabel: {}} # 使用字典类型储存树的信息
del(labels[bestFeat]) # 删除已经使用的特征标签
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:]
# 递归调用创建决策树函数
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
return myTree
# 测试
if __name__ == '__main__':
dataSet, labels = createDataSet()
myTree = createTree(dataSet, labels)
print(myTree)
```
代码的主要实现过程如下:
1. `calcShannonEnt(dataSet)`:计算数据集的香农熵,即公式 $H=-\sum_{i=1}^{n}p_i\log_2p_i$,其中 $n$ 是类别数目,$p_i$ 是第 $i$ 个类别占总数的比例。
2. `splitDataSet(dataSet, axis, value)`:按照给定特征划分数据集,即找出数据集中第 $axis$ 个特征取值等于 $value$ 的所有实例。
3. `chooseBestFeatureToSplit(dataSet)`:选择最好的数据集划分方式,即计算每个特征的信息增益,选取信息增益最大的特征作为划分依据。
4. `createTree(dataSet, labels)`:使用递归方法建立决策树。建树过程中,如果所有样本属于同一类别,则返回该类别;如果遍历完所有特征,仍有多个类别,使用多数表决规则确定叶子节点的类别;否则选择信息增益最大的特征作为划分依据,递归创建左子树和右子树。
其中,关键步骤 `calcShannonEnt()`、`splitDataSet()` 和 `chooseBestFeatureToSplit()` 都有详细的注释,可以结合代码理解。
阅读全文