手写 Python 决策树算法,能够应用于定类数据
时间: 2024-05-04 20:21:40 浏览: 36
决策树算法实现(使用MNIST数据集)_Python环境
好的,下面是一个简单的Python决策树算法,可以对定类数据进行分类:
```python
# 导入必要的库
import numpy as np
# 定义节点类
class Node:
    """A single node of the decision tree.

    An internal node stores the split it applies (``feature`` column index and
    ``value``) together with its two children; a leaf stores only ``results``.

    Attributes:
        feature: Column index the node splits on (None for a leaf).
        value: Split value that ``row[feature]`` is tested against.
        results: Class-count dict stored on a leaf (None for an internal node).
        tb: Subtree followed when the split test is true.
        fb: Subtree followed when the split test is false.
    """

    def __init__(self, feature=None, value=None, results=None, tb=None, fb=None):
        self.feature = feature  # column index used for the split
        self.value = value      # value the split column is compared with
        self.results = results  # class counts; only set on leaves
        self.tb = tb            # "true" branch
        self.fb = fb            # "false" branch

    def __repr__(self):
        # Compact form for debugging; children are omitted to keep it short.
        if self.results is not None:
            return "Node(results=%r)" % (self.results,)
        return "Node(feature=%r, value=%r)" % (self.feature, self.value)
# 定义决策树算法
def buildTree(data):
    """Recursively build a binary decision tree from ``data``.

    Each row of ``data`` is a sequence whose last element is the class label
    and whose other elements are feature values.  Every (column, value) pair
    occurring in the data is tried as a split, and the one giving the largest
    reduction of ``scoreFunction`` is kept.

    Args:
        data: List of rows (feature values followed by the class label).

    Returns:
        Node: Root of the tree.  Leaves carry the class counts of the rows
        that reached them; empty input yields a leaf with empty counts.
    """
    # len() rather than truthiness so array-like inputs are handled as well.
    if len(data) == 0:
        # Return an empty leaf instead of a bare Node(): a node with neither
        # results nor a split feature cannot be walked by classify().
        return Node(results={})

    currentScore = scoreFunction(data)  # impurity before splitting
    bestGain = 0.0
    bestCriteria = None
    bestSets = None

    for column in range(len(data[0]) - 1):
        # dict.fromkeys de-duplicates while keeping first-seen order, so ties
        # between equally good splits are always resolved the same way.
        for value in dict.fromkeys(row[column] for row in data):
            set1, set2 = divideSet(data, column, value)
            # A split that leaves one side empty separates nothing; skip it
            # before paying for two impurity computations.
            if len(set1) == 0 or len(set2) == 0:
                continue
            p = float(len(set1)) / len(data)
            gain = currentScore - p * scoreFunction(set1) - (1 - p) * scoreFunction(set2)
            if gain > bestGain:
                bestGain = gain
                bestCriteria = (column, value)
                bestSets = (set1, set2)

    if bestGain > 0:
        trueBranch = buildTree(bestSets[0])
        falseBranch = buildTree(bestSets[1])
        return Node(feature=bestCriteria[0], value=bestCriteria[1],
                    tb=trueBranch, fb=falseBranch)

    # No split improves the score: make this a leaf holding the class counts.
    return Node(results=uniqueCounts(data))
# 定义分类函数
def classify(observation, tree):
    """Return the class counts of the leaf that ``observation`` falls into.

    Args:
        observation: Sequence of feature values, indexed by the same column
            positions as the rows the tree was built from.
        tree: Root node of a (sub)tree; must expose ``feature``, ``value``,
            ``results``, ``tb`` and ``fb`` attributes.

    Returns:
        dict: The ``results`` mapping stored on the reached leaf (label ->
        count).  An empty dict is returned for a node that has neither
        results nor a split feature.
    """
    node = tree
    # Walk down iteratively so deep trees cannot hit the recursion limit.
    while node.results is None:
        if node.feature is None:
            # A node without results and without a split carries no
            # information; report "no class evidence" instead of failing on
            # observation[None].
            return {}
        v = observation[node.feature]
        # Mirror divideSet(): the kind of test is decided by the split value.
        # The threshold test is used only when the observed value is numeric
        # too, so mixed types fall back to equality instead of raising.
        numeric = isinstance(node.value, (int, float)) and isinstance(v, (int, float))
        if numeric:
            takeTrue = v >= node.value
        else:
            takeTrue = v == node.value
        node = node.tb if takeTrue else node.fb
    return node.results
# 定义计算数据集得分的函数
def scoreFunction(rows):
    """Return the Gini impurity of the class labels in ``rows``.

    The impurity is accumulated as the sum of ``p_i * p_j`` over every ordered
    pair of distinct classes, which is mathematically ``1 - sum(p_i ** 2)``.

    Args:
        rows: List of rows whose last element is the class label.

    Returns:
        float: 0.0 for an empty or single-class set, larger values for more
        mixed sets.
    """
    total = len(rows)
    if total == 0:
        return 0.0  # float, consistent with every other return value

    # Compute each class probability once instead of once per pair.
    probabilities = [
        (label, float(count) / total)
        for label, count in uniqueCounts(rows).items()
    ]

    imp = 0.0
    # The pairwise accumulation order is kept on purpose: it reproduces the
    # exact floating-point values that split gains are compared with.
    for k1, p1 in probabilities:
        for k2, p2 in probabilities:
            if k1 == k2:
                continue
            imp += p1 * p2
    return imp
# 定义将数据集根据特征和特征值分裂的函数
def divideSet(rows, column, value):
    """Partition ``rows`` into two lists according to one column.

    When ``value`` is an int or float the test is ``row[column] >= value``;
    for any other type it is ``row[column] == value``.

    Args:
        rows: Iterable of rows (indexable sequences).
        column: Index of the column to test.
        value: Split value.

    Returns:
        tuple: ``(set1, set2)`` where ``set1`` holds the rows passing the test
        and ``set2`` the remaining rows, both in their original order.
    """
    numeric = isinstance(value, (int, float))
    set1 = []
    set2 = []
    # Single pass: the test runs once per row, and ``rows`` may be a one-shot
    # iterator without the second list silently coming out empty.
    for row in rows:
        cell = row[column]
        passes = cell >= value if numeric else cell == value
        if passes:
            set1.append(row)
        else:
            set2.append(row)
    return (set1, set2)
# 定义计算数据集中每个类别出现次数的函数
def uniqueCounts(rows):
    """Count how often each class label occurs in ``rows``.

    Args:
        rows: Iterable of rows whose last element is the class label.

    Returns:
        dict: Maps each label to its number of occurrences, in first-seen
        order; empty when ``rows`` is empty.
    """
    results = {}
    for row in rows:
        label = row[-1]  # the class label is the last column
        results[label] = results.get(label, 0) + 1
    return results
```
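这个决策树算法使用“基尼不纯度”作为评估标准,可以对定类数据进行分类:数值型(int/float)特征按“大于等于某个阈值”划分,其他类型的特征按“是否等于某个取值”划分。注意 classify 返回的是命中叶节点的类别计数字典(例如 {0: 5}),而不是单个类别标签。使用方法如下: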
```python
# Prepare the training data: two numeric features followed by the class label.
data = [
    [2.771244718, 1.784783929, 0],
    [1.728571309, 1.169761413, 0],
    [3.678319846, 2.81281357, 0],
    [3.961043357, 2.61995032, 0],
    [2.999208922, 2.209014212, 0],
    [7.497545867, 3.162953546, 1],
    [9.00220326, 3.339047188, 1],
    [7.444542326, 0.476683375, 1],
    [10.12493903, 3.234550982, 1],
    [6.642287351, 3.319983761, 1],
]
# Build the decision tree.
tree = buildTree(data)
# Classify a new observation.  classify() returns the class counts of the
# leaf that was reached (here {0: 5}), not a single label.
newData = [2.771244718, 1.784783929]
result = classify(newData, tree)
# Take the most frequent class in that leaf as the predicted label.
predicted = max(result, key=result.get)
print(predicted)  # prints 0: the observation belongs to the first class
```
以上代码仅作为示例,实际使用时可能需要根据具体需求进行修改和优化。
阅读全文