决策树迭代python
时间: 2023-05-30 22:02:18 浏览: 59
决策树迭代算法的实现可以使用Python编程语言。以下是一个基本的决策树迭代算法的Python代码示例:
```python
import numpy as np
class TreeNode:
    """A single node of the decision tree.

    Internal nodes carry a split (feature_index, threshold) plus two
    children; leaf nodes carry only a predicted ``value`` (all other
    fields stay None).
    """

    def __init__(self, feature_index=None, threshold=None, left=None, right=None, value=None):
        self.feature_index = feature_index  # index of the feature used for the split
        self.threshold = threshold          # split threshold (samples go left when <=)
        self.left = left                    # left child subtree
        self.right = right                  # right child subtree
        self.value = value                  # predicted label (leaf nodes only)


class DecisionTree:
    """CART-style classification tree trained with entropy / information gain.

    Parameters
    ----------
    max_depth : int or None
        Maximum tree depth; None means grow until other stop criteria hit.
    min_samples_split : int
        Minimum number of samples required to attempt a split.
    """

    def __init__(self, max_depth=None, min_samples_split=2):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split

    def fit(self, X, y):
        """Build the tree from data X of shape (n_samples, n_features) and labels y."""
        self.n_features = X.shape[1]
        self.tree = self._grow_tree(X, y)

    def predict(self, X):
        """Return a list of predicted labels, one per row of X."""
        return [self._predict(inputs) for inputs in X]

    def _grow_tree(self, X, y, depth=0):
        """Recursively grow the tree; return the root TreeNode of the subtree."""
        n_samples, n_features = X.shape
        n_labels = len(np.unique(y))
        # Stop: too few samples, max depth reached, or node already pure.
        if (n_samples < self.min_samples_split
                or depth == self.max_depth
                or n_labels == 1):
            return TreeNode(value=self._leaf_value(y))
        # Scan every feature for the best (feature, threshold) split.
        # Fix: the original drew np.random.choice(n_features, n_features,
        # replace=False) — a random permutation of ALL features, i.e. pure
        # global-RNG noise. A deterministic range is equivalent and reproducible.
        feature_indices = np.arange(n_features)
        best_feature, best_threshold = self._best_criteria(X, y, feature_indices)
        # Fix: if no split has positive gain, make a leaf. The original could
        # pick a zero-gain split with an empty child, crashing _leaf_value on
        # an empty array or recursing forever when max_depth is None.
        if best_feature is None:
            return TreeNode(value=self._leaf_value(y))
        left_indices, right_indices = self._split(X[:, best_feature], best_threshold)
        left = self._grow_tree(X[left_indices, :], y[left_indices], depth + 1)
        right = self._grow_tree(X[right_indices, :], y[right_indices], depth + 1)
        return TreeNode(best_feature, best_threshold, left, right)

    def _best_criteria(self, X, y, feature_indices):
        """Return the (feature_index, threshold) pair maximizing information
        gain, or (None, None) when no candidate split yields positive gain."""
        best_gain = 0.0  # require strictly positive gain to split
        split_idx, split_threshold = None, None
        for feature_index in feature_indices:
            X_column = X[:, feature_index]
            for threshold in np.unique(X_column):
                gain = self._information_gain(y, X_column, threshold)
                if gain > best_gain:
                    best_gain = gain
                    split_idx = feature_index
                    split_threshold = threshold
        return split_idx, split_threshold

    def _information_gain(self, y, X_column, split_threshold):
        """Entropy reduction achieved by splitting X_column at split_threshold."""
        parent_entropy = self._entropy(y)
        left_indices, right_indices = self._split(X_column, split_threshold)
        # A split that leaves one side empty separates nothing.
        if len(left_indices) == 0 or len(right_indices) == 0:
            return 0
        n = len(y)
        nl, nr = len(left_indices), len(right_indices)
        child_entropy = ((nl / n) * self._entropy(y[left_indices])
                         + (nr / n) * self._entropy(y[right_indices]))
        return parent_entropy - child_entropy

    def _split(self, X_column, split_threshold):
        """Index arrays of samples going left (<= threshold) and right (>)."""
        left_indices = np.argwhere(X_column <= split_threshold).flatten()
        right_indices = np.argwhere(X_column > split_threshold).flatten()
        return left_indices, right_indices

    def _entropy(self, y):
        """Shannon entropy (base 2) of the label distribution of y."""
        _, counts = np.unique(y, return_counts=True)
        p = counts / len(y)
        return -np.sum(p * np.log2(p))

    def _leaf_value(self, y):
        """Most common label in y (majority vote).

        Fix: index the *unique values* array with argmax(counts). The
        original did ``y[np.argmax(counts)]`` — indexing the raw label
        array with a position meant for the unique-values array — which
        returns the wrong label whenever the majority class does not
        happen to sit at that position in y (e.g. y=[5,0,0] gave 5, not 0).
        """
        values, counts = np.unique(y, return_counts=True)
        return values[np.argmax(counts)]

    def _predict(self, inputs):
        """Walk a single sample from the root to a leaf; return the leaf value."""
        node = self.tree
        while node.left is not None:  # leaves have no children
            if inputs[node.feature_index] <= node.threshold:
                node = node.left
            else:
                node = node.right
        return node.value
```
这个示例代码使用了numpy库,并定义了一个TreeNode类和一个DecisionTree类。在DecisionTree类中,`fit`方法用于训练决策树,`predict`方法用于预测新的数据。`_grow_tree`方法是决策树递归构建的核心函数,`_best_criteria`方法选择最佳的特征和阈值进行划分,`_information_gain`方法计算信息增益,`_split`方法根据阈值划分数据集,`_entropy`方法计算熵,`_leaf_value`方法计算叶节点的预测值,`_predict`方法使用构建好的决策树对输入数据进行预测。