西瓜数据集决策树python不调库代码实现
时间: 2023-12-08 19:03:01 浏览: 182
python使用ID3、C4.5、CART实现西瓜数据集决策树并画出图像
以下是使用Python实现西瓜数据集决策树的代码,不需要调用任何库:
```python
import math
# 定义节点类
class Node:
def __init__(self, feature=None, value=None, results=None, tb=None, fb=None):
self.feature = feature # 用于划分数据集的特征
self.value = value # 特征的值
self.results = results # 存储叶子节点的分类结果
self.tb = tb # 左子树
self.fb = fb # 右子树
# 计算数据集的熵
def entropy(data):
results = {}
for row in data:
r = row[-1]
if r not in results:
results[r] = 0
results[r] += 1
ent = 0.0
for r in results:
p = float(results[r]) / len(data)
ent -= p * math.log(p, 2)
return ent
# 根据特征和特征值划分数据集
def divide_data(data, feature, value):
split_func = None
if isinstance(value, int) or isinstance(value, float):
split_func = lambda row: row[feature] >= value
else:
split_func = lambda row: row[feature] == value
set1 = [row for row in data if split_func(row)]
set2 = [row for row in data if not split_func(row)]
return (set1, set2)
# 选择最好的特征和特征值来划分数据集
def find_best_feature(data):
best_feature = -1
best_value = None
best_gain = 0.0
base_entropy = entropy(data)
for feature in range(len(data[0]) - 1):
feature_values = set([row[feature] for row in data])
for value in feature_values:
set1, set2 = divide_data(data, feature, value)
p = float(len(set1)) / len(data)
gain = base_entropy - p * entropy(set1) - (1 - p) * entropy(set2)
if gain > best_gain:
best_feature = feature
best_value = value
best_gain = gain
return (best_feature, best_value)
# 构建决策树
def build_tree(data):
if len(data) == 0:
return Node()
results = [row[-1] for row in data]
if results.count(results[0]) == len(results):
return Node(results=results[0])
best_feature, best_value = find_best_feature(data)
set1, set2 = divide_data(data, best_feature, best_value)
tb = build_tree(set1)
fb = build_tree(set2)
return Node(feature=best_feature, value=best_value, tb=tb, fb=fb)
# 打印决策树
def print_tree(tree, indent=''):
if tree.results is not None:
print(str(tree.results))
else:
print(str(tree.feature) + ':' + str(tree.value) + '? ')
print(indent + 'T->', end='')
print_tree(tree.tb, indent + ' ')
print(indent + 'F->', end='')
print_tree(tree.fb, indent + ' ')
# 对新数据进行分类
def classify(tree, data):
if tree.results is not None:
return tree.results
else:
v = data[tree.feature]
branch = None
if isinstance(v, int) or isinstance(v, float):
if v >= tree.value:
branch = tree.tb
else:
branch = tree.fb
else:
if v == tree.value:
branch = tree.tb
else:
branch = tree.fb
return classify(branch, data)
# 测试决策树
def test_tree(tree, test_data):
correct = 0
for row in test_data:
if classify(tree, row[:-1]) == row[-1]:
correct += 1
accuracy = float(correct) / len(test_data)
print('Accuracy: %.2f%%' % (accuracy * 100))
# 加载西瓜数据集
def load_watermelon():
data = [
[1, 1, 1, 1, 'yes'],
[1, 1, 1, 0, 'yes'],
[1, 0, 1, 0, 'no'],
[0, 1, 0, 1, 'no'],
[0, 1, 0, 0, 'no'],
[0, 0, 1, 1, 'no'],
[0, 1, 1, 0, 'no'],
[1, 1, 0, 1, 'no'],
[1, 0, 0, 0, 'no'],
[0, 1, 0, 1, 'no']
]
return data
# 加载西瓜数据集2
def load_watermelon2():
data = [
[0.697, 0.460, 1, 'yes'],
[0.774, 0.376, 1, 'yes'],
[0.634, 0.264, 1, 'yes'],
[0.608, 0.318, 1, 'yes'],
[0.556, 0.215, 1, 'yes'],
[0.403, 0.237, 1, 'yes'],
[0.481, 0.149, 1, 'yes'],
[0.437, 0.211, 1, 'yes'],
[0.666, 0.091, 0, 'no'],
[0.243, 0.267, 0, 'no'],
[0.245, 0.057, 0, 'no'],
[0.343, 0.099, 0, 'no'],
[0.639, 0.161, 0, 'no'],
[0.657, 0.198, 0, 'no'],
[0.360, 0.370, 0, 'no'],
[0.593, 0.042, 0, 'no'],
[0.719, 0.103, 0, 'no']
]
return data
# 加载西瓜数据集3
def load_watermelon3():
data = [
['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '是'],
['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '是'],
['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '是'],
['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '是'],
['浅白', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '是'],
['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', '是'],
['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '是'],
['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '是'],
['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '否'],
['青绿', '硬挺', '清脆', '清晰', '平坦', '硬滑', '否'],
['浅白', '硬挺', '清脆', '模糊', '平坦', '软粘', '否'],
['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', '否'],
['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', '否'],
['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', '否'],
['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '否'],
['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', '否'],
['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', '否']
]
return data
# 加载西瓜数据集4
def load_watermelon4():
data = [
['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '是'],
['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '是'],
['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '是'],
['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '是'],
['浅白', '稍蜷', '浊响', '清晰',
阅读全文