决策树算法python实现
时间: 2023-05-30 17:06:15 浏览: 106
以下是一个简单的决策树算法的Python实现,其中使用了递归来构建树:
```
import pandas as pd
import numpy as np
class TreeNode:
    """A single node of the decision tree.

    Internal nodes carry (feature_idx, threshold, left, right); leaf nodes
    carry only ``label``, and ``label is None`` marks an internal node.
    """

    def __init__(self, feature_idx=None, threshold=None, left=None, right=None, label=None):
        self.feature_idx = feature_idx  # index of the feature this node splits on
        self.threshold = threshold      # rows with feature value < threshold go left
        self.left = left                # left child TreeNode, or None for a leaf
        self.right = right              # right child TreeNode, or None for a leaf
        self.label = label              # predicted class label (leaves only)


class DecisionTree:
    """CART-style decision-tree classifier that splits on Gini impurity gain.

    Parameters
    ----------
    max_depth : int or None
        Maximum depth of the tree; ``None`` grows nodes until they are pure.
    min_samples_leaf : int
        Minimum number of training samples required in every leaf.
    """

    def __init__(self, max_depth=None, min_samples_leaf=1):
        self.max_depth = max_depth
        self.min_samples_leaf = min_samples_leaf
        self.tree = None  # root TreeNode, set by fit()

    def fit(self, X, y):
        """Build the tree from feature matrix X (n_samples, n_features) and labels y."""
        self.tree = self._build_tree(np.asarray(X), np.asarray(y))

    def predict(self, X):
        """Return an array holding the predicted label for each row of X."""
        return np.array([self._predict_tree(x, self.tree) for x in np.asarray(X)])

    @staticmethod
    def _majority_label(y):
        # Most frequent label. Unlike np.bincount(y).argmax() (which requires
        # non-negative integer labels and crashed otherwise), this works for
        # arbitrary label values. Ties break toward the smallest label, which
        # matches bincount's tie-breaking for non-negative int labels.
        values, counts = np.unique(y, return_counts=True)
        return values[counts.argmax()]

    def _build_tree(self, X, y, depth=0):
        """Recursively grow the (sub)tree for samples (X, y); return its root node."""
        n_samples = X.shape[0]
        # Stop when the node is pure or the depth limit is reached.
        # (None == int is always False, so max_depth=None never stops here.)
        if len(np.unique(y)) == 1 or depth == self.max_depth:
            return TreeNode(label=self._majority_label(y))
        # No split could leave both children with >= min_samples_leaf samples.
        if n_samples < 2 * self.min_samples_leaf:
            return TreeNode(label=self._majority_label(y))
        best_feature, best_threshold = self._find_best_split(X, y)
        if best_feature is None:  # no valid split exists; make a leaf
            return TreeNode(label=self._majority_label(y))
        left_idxs = X[:, best_feature] < best_threshold
        left = self._build_tree(X[left_idxs], y[left_idxs], depth + 1)
        right = self._build_tree(X[~left_idxs], y[~left_idxs], depth + 1)
        return TreeNode(feature_idx=best_feature, threshold=best_threshold,
                        left=left, right=right)

    def _find_best_split(self, X, y):
        """Exhaustively search all features/thresholds for the max-Gini-gain split.

        Returns (feature_idx, threshold) of the best split, or (None, None)
        when no split satisfies the leaf-size constraint.
        """
        best_gain = -np.inf
        best_feature, best_threshold = None, None
        for feature_idx in range(X.shape[1]):
            for threshold in np.unique(X[:, feature_idx]):
                left_idxs = X[:, feature_idx] < threshold
                y_left, y_right = y[left_idxs], y[~left_idxs]
                # Skip splits that would create an undersized (or empty) leaf.
                # Bug fix: the original only checked the *parent* node size,
                # so leaves could end up smaller than min_samples_leaf.
                if (len(y_left) < self.min_samples_leaf
                        or len(y_right) < self.min_samples_leaf):
                    continue
                gain = self._gini_gain(y, y_left, y_right)
                if gain > best_gain:
                    best_gain = gain
                    best_feature, best_threshold = feature_idx, threshold
        if best_gain == -np.inf:
            return None, None
        return best_feature, best_threshold

    @staticmethod
    def _gini(y):
        """Gini impurity of label array y: 1 - sum_k(p_k^2)."""
        _, counts = np.unique(y, return_counts=True)
        p = counts / len(y)
        return 1 - np.sum(p ** 2)

    def _gini_gain(self, y, y_left, y_right):
        """Reduction in Gini impurity obtained by splitting y into (y_left, y_right)."""
        p = len(y_left) / len(y)
        return self._gini(y) - p * self._gini(y_left) - (1 - p) * self._gini(y_right)

    def _predict_tree(self, x, node):
        """Walk the tree from node down to a leaf for a single sample x; return its label."""
        if node.label is not None:
            return node.label
        if x[node.feature_idx] < node.threshold:
            return self._predict_tree(x, node.left)
        return self._predict_tree(x, node.right)
```
在使用时,可以将数据集和标签传递给`fit`方法,然后使用`predict`方法进行预测。例如:
```
# Load the feature matrix and label vector from CSV files on disk
# (assumes data.csv / labels.csv exist; ravel() flattens labels to 1-D).
X = pd.read_csv('data.csv').values
y = pd.read_csv('labels.csv').values.ravel()
# Fit a depth-limited tree, then predict on the training data itself.
tree = DecisionTree(max_depth=3, min_samples_leaf=2)
tree.fit(X, y)
y_pred = tree.predict(X)
```