Write an example decision tree algorithm in Python
Sure. Below is a simple Python implementation of a decision tree classifier:
```python
import numpy as np


class Node:
    def __init__(self, feature=None, threshold=None, left=None, right=None, value=None):
        self.feature = feature      # index of the feature used for the split
        self.threshold = threshold  # split threshold
        self.left = left            # left subtree (feature <= threshold)
        self.right = right          # right subtree (feature > threshold)
        self.value = value          # class label, set only on leaf nodes


class DecisionTree:
    def __init__(self, max_depth=None):
        self.max_depth = max_depth  # maximum tree depth (None = unlimited)
        self.tree = None            # root node of the fitted tree

    def fit(self, X, y):
        self.tree = self._build_tree(X, y, depth=0)

    def predict(self, X):
        return np.array([self._predict(x) for x in X])

    def _build_tree(self, X, y, depth):
        n_samples, n_features = X.shape
        n_labels = len(np.unique(y))
        # Stop if all samples share one label or the maximum depth is reached;
        # return a leaf holding the majority class.
        if n_labels == 1 or (self.max_depth is not None and depth >= self.max_depth):
            return Node(value=self._most_common_label(y))
        # Find the best feature/threshold pair to split on.
        best_feature, best_threshold = self._find_best_split(X, y, n_samples, n_features)
        # If no valid split exists, return a leaf with the majority class.
        if best_feature is None:
            return Node(value=self._most_common_label(y))
        # Recursively build the left and right subtrees.
        left_indices = X[:, best_feature] <= best_threshold
        right_indices = X[:, best_feature] > best_threshold
        left = self._build_tree(X[left_indices], y[left_indices], depth + 1)
        right = self._build_tree(X[right_indices], y[right_indices], depth + 1)
        return Node(feature=best_feature, threshold=best_threshold, left=left, right=right)

    def _find_best_split(self, X, y, n_samples, n_features):
        best_feature = None
        best_threshold = None
        best_gini = 1.0
        # Try every feature and every observed value as a candidate threshold.
        for feature in range(n_features):
            thresholds = np.unique(X[:, feature])
            for threshold in thresholds:
                left_indices = X[:, feature] <= threshold
                right_indices = X[:, feature] > threshold
                n_left = np.sum(left_indices)
                n_right = np.sum(right_indices)
                # Skip splits that leave one side empty.
                if n_left == 0 or n_right == 0:
                    continue
                # Weighted Gini impurity of the two children.
                gini = (n_left / n_samples) * self._gini(y[left_indices]) + \
                       (n_right / n_samples) * self._gini(y[right_indices])
                if gini < best_gini:
                    best_feature = feature
                    best_threshold = threshold
                    best_gini = gini
        return best_feature, best_threshold

    def _gini(self, y):
        # Gini impurity: 1 minus the sum of squared class proportions.
        n_samples = len(y)
        _, counts = np.unique(y, return_counts=True)
        impurity = 1.0
        for count in counts:
            p = count / n_samples
            impurity -= p ** 2
        return impurity

    def _most_common_label(self, y):
        labels, counts = np.unique(y, return_counts=True)
        return labels[np.argmax(counts)]

    def _predict(self, x):
        # Walk down the tree until a leaf (value is not None) is reached.
        node = self.tree
        while node.value is None:
            if x[node.feature] <= node.threshold:
                node = node.left
            else:
                node = node.right
        return node.value
```
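The split criterion is the Gini impurity, 1 - Σ p_k², computed by `_gini` and weighted by each child's share of the samples in `_find_best_split`, so the tree performs CART-style binary threshold splits. As a quick standalone sanity check of the formula (this snippet is only illustrative and is not part of the class above):

```python
import numpy as np

def gini(y):
    # Gini impurity: 1 - sum of squared class proportions
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

# For the toy labels used below, 2 of 5 samples are class 0 and 3 of 5 are class 1:
print(gini(np.array([0, 1, 1, 1, 0])))  # 1 - (0.4**2 + 0.6**2) = 0.48
```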
It can be used as follows:
```python
# Build a toy dataset
X = np.array([[2.0, 3.0],
              [2.0, 1.0],
              [1.0, 2.0],
              [3.0, 1.0],
              [3.0, 2.0]])
y = np.array([0, 1, 1, 1, 0])

# Fit the decision tree
dt = DecisionTree(max_depth=2)
dt.fit(X, y)

# Predict new samples
X_new = np.array([[1.0, 1.0],
                  [3.0, 3.0]])
y_pred = dt.predict(X_new)
print(y_pred)  # [1 0]
```
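For comparison, the same toy data can be run through scikit-learn's DecisionTreeClassifier, which also uses Gini impurity by default. This is only a sketch for cross-checking and assumes scikit-learn is installed; it is not required by the implementation above:

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

X = np.array([[2.0, 3.0], [2.0, 1.0], [1.0, 2.0], [3.0, 1.0], [3.0, 2.0]])
y = np.array([0, 1, 1, 1, 0])

clf = DecisionTreeClassifier(criterion="gini", max_depth=2, random_state=0)
clf.fit(X, y)
print(clf.predict(np.array([[1.0, 1.0], [3.0, 3.0]])))  # compare with the prediction above
```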