决策树算法python实现
时间: 2023-07-20 08:26:51 浏览: 86
决策树算法是一种基于树结构进行决策的算法,它可以用于分类和回归问题。以下是一个简单的决策树算法的 Python 实现:
```python
import numpy as np
class DecisionTree:
    """Minimal CART-style decision tree classifier.

    Splits are chosen greedily by maximising the reduction in Gini
    impurity; prediction walks the fitted tree until a leaf is reached.
    """

    def __init__(self):
        # Root Node of the fitted tree; stays None until fit() is called.
        self.tree = None

    def fit(self, X, y):
        """Grow the tree from X of shape (n_samples, n_features) and labels y."""
        X = np.asarray(X)
        y = np.asarray(y)
        self.tree = self._build_tree(X, y)

    def predict(self, X):
        """Return a list with one predicted label per row of X."""
        if self.tree is None:
            raise RuntimeError("fit() must be called before predict()")
        predictions = []
        for sample in np.asarray(X):
            node = self.tree
            # Internal nodes always have both children; leaves have neither.
            while node.left is not None:
                if sample[node.feature] <= node.threshold:
                    node = node.left
                else:
                    node = node.right
            predictions.append(node.value)
        return predictions

    def _build_tree(self, X, y):
        """Recursively build and return the subtree for the samples (X, y)."""
        # Pure node: nothing left to split.
        if len(np.unique(y)) == 1:
            return Node(value=y[0])
        best_feature, best_threshold = self._find_best_split(X, y)
        if best_feature is None:
            # No usable split (e.g. identical rows with conflicting labels):
            # stop with a majority-class leaf instead of recursing forever.
            return Node(value=self._majority_class(y))
        left_idxs = X[:, best_feature] <= best_threshold
        right_idxs = ~left_idxs
        left = self._build_tree(X[left_idxs], y[left_idxs])
        right = self._build_tree(X[right_idxs], y[right_idxs])
        return Node(feature=best_feature, threshold=best_threshold,
                    left=left, right=right)

    def _find_best_split(self, X, y):
        """Return (feature, threshold) of the best Gini split, or (None, None)."""
        n_features = X.shape[1]
        best_gain = 0.0  # demand a strictly positive impurity reduction
        best_feature = None
        best_threshold = None
        for feature in range(n_features):
            thresholds = np.unique(X[:, feature])
            # Drop the largest value: "<= max" would send every sample left,
            # producing an empty right child and no progress.
            for threshold in thresholds[:-1]:
                left_idxs = X[:, feature] <= threshold
                gain = self._calculate_gain(y, left_idxs, ~left_idxs)
                if gain > best_gain:
                    best_gain = gain
                    best_feature = feature
                    best_threshold = threshold
        return best_feature, best_threshold

    def _calculate_purity(self, y):
        """Gini impurity of label vector y (0.0 for a pure or empty node)."""
        if len(y) == 0:
            return 0.0
        _, counts = np.unique(y, return_counts=True)
        probabilities = counts / counts.sum()
        return float(1 - np.sum(probabilities ** 2))

    def _calculate_gain(self, y, left_idxs, right_idxs):
        """Gini impurity reduction of splitting y by the two boolean masks."""
        # Count set bits, not len(mask): a boolean mask's len() is the full
        # sample count, which made both child weights 0.5 in the original.
        n_left = int(np.count_nonzero(left_idxs))
        n_right = int(np.count_nonzero(right_idxs))
        n_total = n_left + n_right
        if n_left == 0 or n_right == 0:
            return 0.0  # degenerate split: no impurity reduction possible
        children = (n_left / n_total) * self._calculate_purity(y[left_idxs]) \
            + (n_right / n_total) * self._calculate_purity(y[right_idxs])
        # Compare parent and children with the SAME criterion (Gini), not entropy.
        return self._calculate_purity(y) - children

    def _calculate_entropy(self, y):
        """Shannon entropy of y in bits (kept for reference; unused by Gini gain)."""
        _, counts = np.unique(y, return_counts=True)
        probabilities = counts / counts.sum()
        return float(np.sum(probabilities * -np.log2(probabilities)))

    def _majority_class(self, y):
        """Most frequent label in y (ties broken by np.unique's sorted order)."""
        values, counts = np.unique(y, return_counts=True)
        return values[np.argmax(counts)]
class Node:
    """A single decision-tree node.

    Internal nodes carry (feature, threshold, left, right); leaf nodes
    carry only ``value``. Any attribute not supplied defaults to None.
    """

    def __init__(self, feature=None, threshold=None, left=None, right=None, value=None):
        # Bind each constructor argument to an identically named attribute.
        self.feature, self.threshold = feature, threshold
        self.left, self.right = left, right
        self.value = value
```
这里实现的决策树算法采用基尼不纯度作为分裂准则,采用递归的方式构建决策树。在构建决策树时,算法会选择最佳的分裂点,直到无法再分裂为止。在预测时,算法会遍历决策树,根据样本的特征值逐步向下走,直到叶子节点为止,叶子节点的值就是该样本的预测值。
阅读全文