Complete Python code for running the CART algorithm on a set of data
Implementing the CART decision tree algorithm in Python
Below is complete Python code that uses the CART algorithm to classify a set of data:
```python
import numpy as np
from collections import Counter


class Node:
    """Tree node: either an internal split node or a leaf holding a class label."""
    def __init__(self, label, feature=None, threshold=None, left=None, right=None, is_leaf=False):
        self.label = label
        self.feature = feature
        self.threshold = threshold
        self.left = left
        self.right = right
        self.is_leaf = is_leaf


class CARTDecisionTree:
    def __init__(self, max_depth=5, min_samples_split=2, min_samples_leaf=1):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.tree = None

    def fit(self, X, y):
        self.tree = self.build_tree(X, y)

    def predict(self, X):
        return np.array([self.predict_one(x, self.tree) for x in X])

    def predict_one(self, x, node):
        # Walk down the tree until a leaf is reached.
        if node.is_leaf:
            return node.label
        if x[node.feature] <= node.threshold:
            return self.predict_one(x, node.left)
        return self.predict_one(x, node.right)

    def build_tree(self, X, y, depth=0):
        # Stop splitting when the depth limit is reached, the node is too
        # small to split, or the node is already pure.
        if depth == self.max_depth or y.size < self.min_samples_split or len(set(y)) == 1:
            return Node(self.get_label(y), is_leaf=True)
        best_feature, best_threshold = self.get_best_split(X, y)
        if best_feature is None or best_threshold is None:
            return Node(self.get_label(y), is_leaf=True)
        left_indices = X[:, best_feature] <= best_threshold
        right_indices = X[:, best_feature] > best_threshold
        left = self.build_tree(X[left_indices], y[left_indices], depth + 1)
        right = self.build_tree(X[right_indices], y[right_indices], depth + 1)
        return Node(None, feature=best_feature, threshold=best_threshold, left=left, right=right)

    def get_best_split(self, X, y):
        # Exhaustively search every (feature, threshold) pair for the lowest weighted Gini.
        best_feature, best_threshold = None, None
        best_gini = 1.0
        for feature in range(X.shape[1]):
            thresholds = np.unique(X[:, feature])
            for threshold in thresholds:
                left_size = np.sum(X[:, feature] <= threshold)
                # Skip splits that would leave fewer than min_samples_leaf samples on either side.
                if left_size < self.min_samples_leaf or y.size - left_size < self.min_samples_leaf:
                    continue
                gini = self.gini_index(X, y, feature, threshold)
                if gini < best_gini:
                    best_feature, best_threshold, best_gini = feature, threshold, gini
        return best_feature, best_threshold

    def gini_index(self, X, y, feature, threshold):
        # Weighted Gini impurity of the two child nodes produced by the split.
        left_indices = X[:, feature] <= threshold
        right_indices = X[:, feature] > threshold
        left_gini = self.get_gini(y[left_indices])
        right_gini = self.get_gini(y[right_indices])
        n = y.size
        left_size, right_size = y[left_indices].size, y[right_indices].size
        return (left_size / n) * left_gini + (right_size / n) * right_gini

    def get_gini(self, y):
        # Gini impurity: 1 minus the sum of squared class proportions.
        counter = Counter(y)
        return 1 - sum((v / y.size) ** 2 for v in counter.values())

    def get_label(self, y):
        # Majority class of the samples in this node.
        counter = Counter(y)
        return counter.most_common(1)[0][0]


if __name__ == '__main__':
    # Build a small two-dimensional dataset.
    X = np.array([[1, 1], [1, 2], [2, 1], [2, 2], [3, 1], [3, 2], [4, 1], [4, 2]])
    y = np.array([0, 0, 0, 0, 1, 1, 1, 1])
    # Train the model.
    clf = CARTDecisionTree()
    clf.fit(X, y)
    # Predict new data points.
    x_new = np.array([[1, 1.5], [3.5, 1.5]])
    y_pred = clf.predict(x_new)
    print(y_pred)
```
The code first constructs a two-dimensional dataset `X` with corresponding labels `y`, trains a decision tree on it with the CART algorithm, and then uses the trained model to classify two new data points and print the predictions. On this toy dataset the tree learns a single split on the first feature at threshold 2, so the script should print `[0 1]`.
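As an optional sanity check (a minimal sketch, assuming scikit-learn is installed; the `sk_clf` variable name is just for illustration), the same data can be fed to scikit-learn's `DecisionTreeClassifier`, which implements an optimized version of CART, and its predictions compared with the hand-rolled tree above:
```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

X = np.array([[1, 1], [1, 2], [2, 1], [2, 2], [3, 1], [3, 2], [4, 1], [4, 2]])
y = np.array([0, 0, 0, 0, 1, 1, 1, 1])

# criterion='gini' matches the impurity measure used in the implementation above,
# and the tree-size parameters mirror CARTDecisionTree's defaults.
sk_clf = DecisionTreeClassifier(criterion='gini', max_depth=5,
                                min_samples_split=2, min_samples_leaf=1)
sk_clf.fit(X, y)

x_new = np.array([[1, 1.5], [3.5, 1.5]])
print(sk_clf.predict(x_new))  # expected to agree with the hand-rolled tree on this toy data
```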