分类分析(CART算法)用python实现完整代码
时间: 2023-12-16 10:06:39 浏览: 28
下面是使用Python实现CART算法的完整代码:
```python
import numpy as np
import pandas as pd
class CART:
    """CART-style binary decision-tree classifier.

    Note: despite the name, splits are scored with entropy / information
    gain rather than the Gini index traditionally used by CART.
    Nodes are plain dicts: leaves are {'is_leaf': True, 'value': label},
    internal nodes carry 'feature', 'split_val', 'left', 'right'.
    """

    def __init__(self, max_depth=5, min_samples_split=2, min_samples_leaf=1):
        # max_depth: maximum tree depth before forcing a leaf.
        # min_samples_split: smallest node size that may still be split.
        # min_samples_leaf: smallest allowed child size for a candidate split.
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf

    def fit(self, X, y):
        """Grow the tree from X of shape (n_samples, n_features) and labels y."""
        self.tree = self.build_tree(X, y)

    def predict(self, X):
        """Return an array with the predicted label for each row of X."""
        return np.array([self.predict_row(x, self.tree) for x in X])

    def predict_row(self, x, tree):
        """Walk the tree for a single sample x and return the leaf's label."""
        if tree['is_leaf']:
            return tree['value']
        if x[tree['feature']] <= tree['split_val']:
            return self.predict_row(x, tree['left'])
        return self.predict_row(x, tree['right'])

    def build_tree(self, X, y, depth=0):
        """Recursively build the tree; returns the root node dict."""
        n_samples, n_features = X.shape
        n_labels = len(np.unique(y))
        # Stopping criteria: depth limit reached, node is pure, or too few samples.
        if depth >= self.max_depth or n_labels == 1 or n_samples < self.min_samples_split:
            return {'is_leaf': True, 'value': self.calc_leaf_value(y)}
        # Exhaustive search for the best (feature, threshold) split.
        # The original drew a full no-replacement random permutation of all
        # features, which only randomized tie-breaking; range() is deterministic.
        best_feature_idx, best_split_val, best_gain = None, None, -1.0
        for feature_idx in range(n_features):
            X_column = X[:, feature_idx]
            for split_val in np.unique(X_column):
                left_indices = X_column <= split_val
                right_indices = X_column > split_val
                # BUG FIX: enforce min_samples_leaf (the original only rejected
                # empty sides and ignored the parameter; default 1 is identical).
                if (np.sum(left_indices) < self.min_samples_leaf
                        or np.sum(right_indices) < self.min_samples_leaf):
                    continue
                gain = self.calc_gain(y, left_indices, right_indices)
                if gain > best_gain:
                    best_feature_idx = feature_idx
                    best_split_val = split_val
                    best_gain = gain
        # BUG FIX: if no valid split exists (e.g. all feature values identical),
        # the original crashed indexing X[:, None]; return a leaf instead.
        if best_feature_idx is None:
            return {'is_leaf': True, 'value': self.calc_leaf_value(y)}
        left_indices = X[:, best_feature_idx] <= best_split_val
        right_indices = X[:, best_feature_idx] > best_split_val
        return {'is_leaf': False,
                'feature': best_feature_idx,
                'split_val': best_split_val,
                'left': self.build_tree(X[left_indices], y[left_indices], depth + 1),
                'right': self.build_tree(X[right_indices], y[right_indices], depth + 1)}

    def calc_leaf_value(self, y):
        """Return the majority class label in y.

        BUG FIX: the original returned np.mean(y) — a regression value that
        raises TypeError on string class labels (as in the iris example) and
        produces non-label outputs for integer classes. np.unique sorts the
        labels, so ties break toward the smallest label deterministically.
        """
        values, counts = np.unique(y, return_counts=True)
        return values[np.argmax(counts)]

    def calc_gain(self, y, left_indices, right_indices):
        """Information gain of splitting y by the given boolean masks."""
        p_left = np.sum(left_indices) / len(y)
        p_right = np.sum(right_indices) / len(y)
        child_entropy = (p_left * self.calc_entropy(y[left_indices])
                         + p_right * self.calc_entropy(y[right_indices]))
        return self.calc_entropy(y) - child_entropy

    def calc_entropy(self, y):
        """Shannon entropy (base 2) of the label distribution in y."""
        _, counts = np.unique(y, return_counts=True)
        probs = counts / len(y)
        return -np.sum(probs * np.log2(probs))
```
使用示例:
```python
# Load the iris dataset; expects a CSV with four numeric feature columns
# and a 'class' string-label column (the classic UCI iris file layout).
df = pd.read_csv('iris.csv')
X = df.drop('class', axis=1).values
y = df['class'].values
# Train a shallow decision tree (depth capped at 2).
# NOTE(review): with string labels, calc_leaf_value must return a class
# label rather than np.mean(y) for this example to run — verify.
tree = CART(max_depth=2)
tree.fit(X, y)
# Predict the class of one unseen sample (sepal/petal measurements in cm).
X_test = np.array([[5.1, 3.5, 1.4, 0.2]])
print(tree.predict(X_test)) # output: ['Iris-setosa']
```