CART 决策树生成算法 Python 代码
时间: 2023-07-23 07:07:41 浏览: 70
以下是一个简单的 CART(分类与回归树)决策树生成算法的 Python 实现。该实现只做分类,每个节点做二叉划分,并以基尼不纯度作为划分标准:
```python
from collections import Counter
import numpy as np
class Node:
    """A single node of a binary decision tree.

    Internal nodes hold a split (``feature``, ``threshold``) and two child
    subtrees; leaf nodes have no children and carry a predicted ``label``.
    """

    def __init__(self, feature=None, threshold=None, left=None, right=None, label=None):
        self.feature = feature  # index of the feature used for the split (internal nodes)
        self.threshold = threshold  # threshold applied to that feature (internal nodes)
        self.left = left  # left subtree
        self.right = right  # right subtree
        self.label = label  # predicted class (leaf nodes)

    @property
    def is_leaf(self):
        """Return True when the node has no children.

        Checks the tree structure instead of ``label is None`` so that a leaf
        whose class label is itself ``None`` is still recognized as a leaf.
        """
        return self.left is None and self.right is None

    def __repr__(self):
        if self.is_leaf:
            return f"Node(label={self.label!r})"
        return f"Node(feature={self.feature!r}, threshold={self.threshold!r})"
class CART:
    """Classification tree grown greedily with the CART algorithm (Gini impurity).

    Parameters
    ----------
    max_depth : int or None, default 5
        Nodes at this depth become leaves. ``None`` disables the depth limit.
    min_samples_split : int, default 2
        Nodes with fewer samples than this become leaves. In addition, a
        candidate split is rejected unless *each* child receives at least this
        many samples (stricter than scikit-learn's parameter of the same name).

    Attributes
    ----------
    root : Node or None
        Root of the fitted tree; ``None`` until ``fit`` has been called.
    """

    def __init__(self, max_depth=5, min_samples_split=2):
        self.max_depth = max_depth  # maximum tree depth (None = unlimited)
        self.min_samples_split = min_samples_split  # minimum samples per split node and per child
        self.root = None  # populated by fit()

    def gini(self, y):
        """Return the Gini impurity ``1 - sum_k p_k ** 2`` of the labels *y*."""
        _, counts = np.unique(y, return_counts=True)
        probs = counts / len(y)
        return 1 - np.sum(probs ** 2)

    def weighted_gini(self, y_left, y_right):
        """Return the size-weighted mean Gini impurity of two child label arrays."""
        n_left, n_right = len(y_left), len(y_right)
        gini_left, gini_right = self.gini(y_left), self.gini(y_right)
        return (n_left * gini_left + n_right * gini_right) / (n_left + n_right)

    def find_best_split(self, X, y):
        """Find the (feature, threshold) split with the lowest weighted Gini impurity.

        Candidate thresholds are the distinct values of each feature column. A
        sample goes left when its value is strictly below the threshold and
        right otherwise. A split is admissible only if both children contain
        at least ``min_samples_split`` samples. On ties the first split found
        is kept.

        Returns:
            ``(feature, threshold, weighted_gini)`` for the best admissible
            split, or ``(None, None, inf)`` when no split is admissible.
        """
        best_feature, best_threshold, best_gini = None, None, np.inf
        for feature in range(X.shape[1]):
            column = X[:, feature]
            for threshold in np.unique(column):
                left = column < threshold
                # Everything not strictly below the threshold (including NaN)
                # goes right, exactly as predict_one routes samples.
                right = ~left
                if left.sum() < self.min_samples_split or right.sum() < self.min_samples_split:
                    continue
                gini = self.weighted_gini(y[left], y[right])
                if gini < best_gini:
                    best_feature, best_threshold, best_gini = feature, threshold, gini
        return best_feature, best_threshold, best_gini

    def _leaf(self, y):
        """Return a leaf Node labelled with the most common class in *y*."""
        return Node(label=Counter(y).most_common(1)[0][0])

    def build_tree(self, X, y, depth):
        """Recursively grow the subtree for the samples (X, y) located at *depth*.

        A majority-label leaf is returned when any of these holds:
        - the depth limit is reached;
        - the node has fewer than ``min_samples_split`` samples;
        - the node is pure;
        - no admissible split exists.
        Otherwise an internal Node is returned, with both children built
        recursively.
        """
        if (
            (self.max_depth is not None and depth >= self.max_depth)
            or len(y) < self.min_samples_split
            or len(set(y)) == 1
        ):
            return self._leaf(y)
        feature, threshold, _ = self.find_best_split(X, y)
        if feature is None:
            # No admissible split (e.g. identical feature values, or every split
            # leaves a child below min_samples_split). Indexing X[:, None] and
            # comparing against None would fail, so stop with a leaf instead.
            return self._leaf(y)
        left_mask = X[:, feature] < threshold
        right_mask = ~left_mask
        left = self.build_tree(X[left_mask], y[left_mask], depth + 1)
        right = self.build_tree(X[right_mask], y[right_mask], depth + 1)
        return Node(feature=feature, threshold=threshold, left=left, right=right)

    def fit(self, X, y):
        """Build the tree from training data and return ``self``.

        Args:
            X: array-like of shape (n_samples, n_features).
            y: array-like of n_samples class labels.

        Raises:
            ValueError: if X is not 2-D, the data set is empty, or X and y
                have different numbers of samples.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        if X.ndim != 2:
            raise ValueError(f"X must be 2-dimensional, got {X.ndim} dimension(s)")
        if len(y) == 0:
            raise ValueError("cannot fit CART on an empty data set")
        if X.shape[0] != len(y):
            raise ValueError(f"X has {X.shape[0]} samples but y has {len(y)}")
        self.root = self.build_tree(X, y, 0)
        return self

    def predict_one(self, x):
        """Return the predicted label for a single sample *x*.

        Descends from the root, going left when ``x[feature] < threshold`` and
        right otherwise, until a leaf is reached.
        """
        node = self.root
        # Internal nodes always carry a split feature, while leaves built by
        # _leaf do not. Testing the feature (not ``label is None``) keeps a
        # None class label from being mistaken for an internal node.
        while node.feature is not None:
            node = node.left if x[node.feature] < node.threshold else node.right
        return node.label

    def predict(self, X):
        """Return an array with one predicted label per row of the array-like *X*."""
        return np.array([self.predict_one(x) for x in np.asarray(X)])
```
使用示例(依赖 scikit-learn 加载数据集和计算准确率,运行前需先定义上面的 Node 和 CART 类):
```python
# Example: train the CART classifier on the Iris data set and report accuracy.
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Load the data set as a feature matrix and a label vector.
iris = load_iris()
X = iris.data
y = iris.target

# Hold out 20% of the samples for evaluation; the fixed seed makes the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# Grow a depth-limited tree on the training portion.
model = CART(max_depth=5, min_samples_split=2)
model.fit(X_train, y_train)

# Score predictions on the held-out portion.
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy:", accuracy)
```