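Create your own simple decision tree based on a real-life scenario. The tree should be able to solve a practical classification/decision problem. Write the code in Python and make it reasonably complete.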
Time: 2024-04-29 11:21:27
Sure. Here is a decision tree example based on restaurant reservations, implemented in Python.
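Problem: when choosing a restaurant, decide whether a reservation is needed based on the following conditions:
1. Does the restaurant have a special event or promotion?
2. What cuisine does the restaurant serve?
3. What is the restaurant's price level?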
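The tree's `Question` objects refer to conditions by column index, so it helps to fix the column order once. A minimal sketch, assuming the first three columns of the dataset are (special event/promotion, cuisine, price level) in that order; `FEATURES` and `to_row` are hypothetical helpers, not part of the original code:

```python
# Hypothetical feature schema: the column order the decision tree expects.
# Training rows additionally carry the label "needs reservation?" as a last column.
FEATURES = ['special_event', 'cuisine', 'price_level']


def to_row(restaurant):
    """Convert a dict of named conditions into the list format used by predict()."""
    missing = [name for name in FEATURES if name not in restaurant]
    if missing:
        raise KeyError(f"missing conditions: {missing}")
    # Emit values in the fixed column order, regardless of dict order
    return [restaurant[name] for name in FEATURES]


print(to_row({'special_event': 'Yes', 'cuisine': 'Chinese', 'price_level': 'Cheap'}))
# ['Yes', 'Chinese', 'Cheap']
```

Naming the conditions this way keeps callers from silently swapping columns, which the equality-based `Question.match` would not detect.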
Example decision tree code:
```python
from collections import Counter
from math import log2


# Internal node: asks a question and routes an example to one of two branches
class DecisionNode:
    def __init__(self, question, true_branch, false_branch):
        self.question = question
        self.true_branch = true_branch
        self.false_branch = false_branch


# Leaf node: stores the final classification ('Yes' or 'No')
class Leaf:
    def __init__(self, value):
        self.value = value


# Question: "is the value in column `column` equal to `value`?"
class Question:
    def __init__(self, column, value):
        self.column = column
        self.value = value

    def match(self, example):
        return example[self.column] == self.value


# Dataset, one row per restaurant:
# [special event/promotion, cuisine, price level, needs reservation (label)]
restaurant_data = [
    ['Yes', 'Italian', 'Expensive', 'Yes'],
    ['No', 'Thai', 'Average', 'No'],
    ['No', 'Chinese', 'Cheap', 'No'],
    ['Yes', 'Mexican', 'Expensive', 'No'],
    ['Yes', 'American', 'Expensive', 'Yes'],
    ['No', 'Italian', 'Average', 'Yes'],
    ['Yes', 'Thai', 'Cheap', 'No'],
]


# Entropy of the label distribution (last column) in rows
def entropy(rows):
    counts = Counter(row[-1] for row in rows)
    result = 0.0
    for count in counts.values():
        p = count / len(rows)
        result -= p * log2(p)
    return result


# Information gain from splitting a node into true_rows and false_rows
def information_gain(true_rows, false_rows):
    p = len(true_rows) / (len(true_rows) + len(false_rows))
    return (entropy(true_rows + false_rows)
            - p * entropy(true_rows)
            - (1 - p) * entropy(false_rows))


# Most common label in rows, used when no question improves the split
def majority_label(rows):
    return Counter(row[-1] for row in rows).most_common(1)[0][0]


# Recursively build the decision tree
def build_tree(data):
    # Stopping conditions: no data, or all rows share one label
    if len(data) == 0:
        return Leaf('No')
    labels = set(row[-1] for row in data)
    if len(labels) == 1:
        return Leaf(labels.pop())
    # Pick the question with the highest information gain
    best_gain = 0
    best_question = None
    for column in range(len(data[0]) - 1):
        values = set(row[column] for row in data)
        for value in values:
            question = Question(column, value)
            true_rows = [row for row in data if question.match(row)]
            false_rows = [row for row in data if not question.match(row)]
            if len(true_rows) == 0 or len(false_rows) == 0:
                continue
            gain = information_gain(true_rows, false_rows)
            if gain > best_gain:
                best_gain, best_question = gain, question
    # If no question reduces entropy (e.g. identical conditions with different
    # labels), best_question stays None: stop with the majority label
    if best_question is None:
        return Leaf(majority_label(data))
    # Recursively build the two subtrees
    true_branch = build_tree([row for row in data if best_question.match(row)])
    false_branch = build_tree([row for row in data if not best_question.match(row)])
    return DecisionNode(best_question, true_branch, false_branch)


# Walk the tree from the root to a leaf and return that leaf's label
def predict(node, example):
    if isinstance(node, Leaf):
        return node.value
    if node.question.match(example):
        return predict(node.true_branch, example)
    return predict(node.false_branch, example)


# Build the decision tree
restaurant_tree = build_tree(restaurant_data)
# Predict for a new restaurant: special event = Yes, cuisine = Chinese, price = Cheap
print(predict(restaurant_tree, ['Yes', 'Chinese', 'Cheap']))
```
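As a hand check of what `information_gain` computes at the root (values rounded to three decimals): the 7 sample rows carry 3 "Yes" and 4 "No" labels, and the question "cuisine == Italian" sends 2 rows (both "Yes") to the true branch and 5 rows (1 "Yes", 4 "No") to the false branch.

$$
\begin{aligned}
H(D) &= -\tfrac{3}{7}\log_2\tfrac{3}{7} - \tfrac{4}{7}\log_2\tfrac{4}{7} \approx 0.985 \\
H(D_{\text{true}}) &= -1 \cdot \log_2 1 = 0 \\
H(D_{\text{false}}) &= -\tfrac{1}{5}\log_2\tfrac{1}{5} - \tfrac{4}{5}\log_2\tfrac{4}{5} \approx 0.722 \\
\text{Gain} &= H(D) - \tfrac{2}{7}H(D_{\text{true}}) - \tfrac{5}{7}H(D_{\text{false}}) \approx 0.470
\end{aligned}
$$

Every other candidate question scores lower (the next best, "cuisine == Thai" and "price == Cheap", each reach about 0.292), so `build_tree` splits on "cuisine == Italian" first.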
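This example treats the restaurant conditions as features, selects the question with the highest information gain (entropy reduction) at each node, and builds the decision tree recursively. Given a new restaurant's information (special event, cuisine, price level), the tree predicts whether that restaurant needs a reservation, outputting "Yes" or "No".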
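Tracing `build_tree` by hand on the 7 sample rows, the root asks whether the cuisine is Italian, its false branch then asks whether the cuisine is American, and all remaining rows end in a "No" leaf. A minimal sketch of the equivalent hand-written rules, useful as a sanity check; the function name `needs_reservation_rules` is hypothetical, and the rules reflect only this tiny dataset:

```python
# Hypothetical hand-written equivalent of the tree learned from the 7 sample
# rows: root question "cuisine == Italian?", then "cuisine == American?".
# The learned tree never consults the special-event or price columns.
def needs_reservation_rules(example):
    cuisine = example[1]
    if cuisine == 'Italian':
        return 'Yes'
    if cuisine == 'American':
        return 'Yes'
    return 'No'


restaurant_data = [
    ['Yes', 'Italian', 'Expensive', 'Yes'],
    ['No', 'Thai', 'Average', 'No'],
    ['No', 'Chinese', 'Cheap', 'No'],
    ['Yes', 'Mexican', 'Expensive', 'No'],
    ['Yes', 'American', 'Expensive', 'Yes'],
    ['No', 'Italian', 'Average', 'Yes'],
    ['Yes', 'Thai', 'Cheap', 'No'],
]

# Count how many training labels the rules reproduce
correct = sum(needs_reservation_rules(row) == row[-1] for row in restaurant_data)
print(f"training accuracy: {correct}/{len(restaurant_data)}")  # training accuracy: 7/7
print(needs_reservation_rules(['Yes', 'Chinese', 'Cheap']))  # No
```

Because cuisine alone separates all 7 rows, the tree has no reason to use the other conditions; more (and held-out) data would be needed before trusting it on real restaurants.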