增益率决策树代码实现
时间: 2024-04-19 09:21:51 浏览: 9
决策树是一种常用的机器学习算法,可用于分类和回归问题;其中以信息增益率(gain ratio)作为划分准则的决策树(如 C4.5)主要用于分类问题。下面是一个简单的、针对离散特征的增益率决策树分类器的代码实现示例:
```python
import numpy as np
class DecisionTree:
    """Gain-ratio (C4.5-style) decision tree classifier for discrete features.

    The fitted tree is stored in ``self.tree``. An internal node is a dict of
    the form ``{feature_idx: {feature_value: subtree, ...}}``; a leaf is a
    bare class label.
    """

    def __init__(self):
        self.tree = {}
        # Majority class of the training set; used by ``predict`` when a
        # sample carries a feature value that never occurred during training.
        self._default_label = None

    def calc_entropy(self, y):
        """Return the Shannon entropy (in bits) of the label collection *y*."""
        if len(y) == 0:
            # Avoid dividing by zero; an empty set carries no uncertainty.
            return 0.0
        _, counts = np.unique(y, return_counts=True)
        probabilities = counts / len(y)
        # ``counts`` are always >= 1, so log2 never sees a zero here.
        return -np.sum(probabilities * np.log2(probabilities))

    def calc_gain_ratio(self, X, y, feature_idx):
        """Return the gain ratio obtained by splitting on column *feature_idx*.

        The gain ratio is the information gain divided by the intrinsic value
        (split information) of the feature. It is defined as 0 when the
        intrinsic value is 0, i.e. when the feature takes a single value.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        column = X[:, feature_idx]
        total = len(y)

        split_entropy = 0.0
        intrinsic_value = 0.0
        for value in np.unique(column):
            subset_y = y[column == value]
            weight = len(subset_y) / total
            split_entropy += weight * self.calc_entropy(subset_y)
            intrinsic_value -= weight * np.log2(weight)

        gain = self.calc_entropy(y) - split_entropy
        return gain / intrinsic_value if intrinsic_value != 0 else 0

    def choose_best_feature(self, X, y):
        """Return the index of the feature with the highest gain ratio.

        Features that take a single value in *X* cannot split the data and
        are skipped. Returns -1 when no feature is able to split the data.
        Ties are resolved in favour of the lowest feature index.
        """
        X = np.asarray(X)
        num_features = X.shape[1]
        best_feature_idx = -1
        best_gain_ratio = -1
        for feature_idx in range(num_features):
            if len(np.unique(X[:, feature_idx])) < 2:
                # Splitting on a constant feature would reproduce the same
                # subset and make the recursion in build_tree endless.
                continue
            gain_ratio = self.calc_gain_ratio(X, y, feature_idx)
            if gain_ratio > best_gain_ratio:
                best_gain_ratio = gain_ratio
                best_feature_idx = feature_idx
        return best_feature_idx

    def build_tree(self, X, y):
        """Recursively build and return the (sub)tree for samples *X*, *y*.

        Returns a class label for a leaf, or a nested dict
        ``{feature_idx: {value: subtree}}`` for an internal node.

        Raises:
            ValueError: if *y* is empty.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        if len(y) == 0:
            raise ValueError("cannot build a decision tree from an empty dataset")

        labels, counts = np.unique(y, return_counts=True)
        if len(labels) == 1:
            # Pure node: return the single label, not the whole array.
            return labels[0]

        # np.unique-based majority works for any label type, unlike
        # np.bincount which only accepts non-negative integers.
        majority_label = labels[np.argmax(counts)]
        if X.ndim < 2 or X.shape[1] == 0:
            return majority_label

        best_feature_idx = self.choose_best_feature(X, y)
        if best_feature_idx == -1:
            # Remaining samples are indistinguishable by any feature.
            return majority_label

        column = X[:, best_feature_idx]
        branches = {}
        for value in np.unique(column):
            mask = column == value
            branches[value] = self.build_tree(X[mask], y[mask])
        return {best_feature_idx: branches}

    def fit(self, X, y):
        """Train the tree on a 2-D feature array *X* and label array *y*.

        Raises:
            ValueError: if *X* is not 2-D, if *X* and *y* differ in length,
                or if the dataset is empty.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        if X.ndim != 2:
            raise ValueError("X must be a 2-D array of shape (n_samples, n_features)")
        if len(X) != len(y):
            raise ValueError("X and y must contain the same number of samples")
        if len(y) == 0:
            raise ValueError("cannot fit a decision tree on an empty dataset")

        labels, counts = np.unique(y, return_counts=True)
        self._default_label = labels[np.argmax(counts)]
        self.tree = self.build_tree(X, y)

    def predict(self, X):
        """Return an array with the predicted label for every row of *X*.

        A sample whose feature value has no branch in the tree (a value not
        seen during training) is assigned the majority class of the training
        set.

        Raises:
            RuntimeError: if the tree has not been fitted yet.
        """
        if isinstance(self.tree, dict) and not self.tree:
            raise RuntimeError("the tree is empty; call fit() before predict()")

        predictions = []
        for sample in np.asarray(X):
            node = self.tree
            while isinstance(node, dict):
                # Every internal node holds exactly one key: the feature index.
                feature_idx = next(iter(node))
                branches = node[feature_idx]
                feature_value = sample[feature_idx]
                if feature_value not in branches:
                    node = self._default_label
                    break
                node = branches[feature_value]
            predictions.append(node)
        return np.array(predictions)
```
这段代码实现了一个简单的增益率决策树算法。其中,`calc_entropy`函数用于计算数据集的熵,`calc_gain_ratio`函数用于计算特征的增益率,`choose_best_feature`函数用于选择最佳特征,`build_tree`函数用于递归构建决策树,`fit`函数用于训练模型,`predict`函数用于进行预测。