Writing a Decision Tree Algorithm in Python
The decision tree is one of the most widely used machine learning algorithms and can be applied to both classification and regression problems. Below is an example implementation of a decision tree classifier in Python.
First, we define a node class to store the information for each node of the tree.
```python
class Node:
    """A single node of the decision tree."""
    def __init__(self, feature=None, threshold=None, left=None, right=None, value=None):
        self.feature = feature      # index of the feature used for the split
        self.threshold = threshold  # threshold the feature is compared against
        self.left = left            # left subtree (feature <= threshold)
        self.right = right          # right subtree (feature > threshold)
        self.value = value          # predicted class (set only for leaf nodes)
```
Here, feature is the index of the feature the node splits on, threshold is the split threshold for that feature, left and right are the node's left and right subtrees, and value is the node's predicted class (used only for leaf nodes).
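As a quick illustration, a leaf node carries only a value, while an internal node carries a split rule and two children (the feature index and threshold below are arbitrary example numbers, not taken from any dataset):
```python
leaf = Node(value=0)                    # leaf node: always predicts class 0
root = Node(feature=2, threshold=2.45,  # internal node: split feature 2 at 2.45
            left=leaf,
            right=Node(value=1))        # right child is a leaf predicting class 1
```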
Next, we define a decision tree class that implements the training and prediction logic.
```python
import numpy as np

class DecisionTree:
    def __init__(self, max_depth=None):
        self.max_depth = max_depth
        self.root = None

    def fit(self, X, y):
        """Build the tree from training data X (n_samples, n_features) and integer labels y."""
        self.root = self._build_tree(X, y)

    def predict(self, X):
        """Predict a class label for every row of X."""
        return np.array([self._predict(inputs) for inputs in X])

    def _build_tree(self, X, y, depth=0):
        n_samples, n_features = X.shape
        n_labels = len(set(y))
        # Stop if the maximum depth is reached, the node is pure,
        # or there are too few samples to split.
        if (self.max_depth is not None and depth >= self.max_depth) \
                or n_labels == 1 or n_samples < 2:
            return Node(value=self._majority_vote(y))
        # Consider a random subset of sqrt(n_features) features at each split,
        # as random forests do; use range(n_features) to consider all features instead.
        feature_indices = np.random.choice(n_features, int(np.sqrt(n_features)), replace=False)
        best_feature, best_threshold = self._best_criteria(X, y, feature_indices)
        left_indices, right_indices = self._split(X[:, best_feature], best_threshold)
        # If the best split leaves one side empty, turn this node into a leaf.
        if len(left_indices) == 0 or len(right_indices) == 0:
            return Node(value=self._majority_vote(y))
        left = self._build_tree(X[left_indices, :], y[left_indices], depth + 1)
        right = self._build_tree(X[right_indices, :], y[right_indices], depth + 1)
        return Node(best_feature, best_threshold, left, right)

    def _best_criteria(self, X, y, feature_indices):
        """Return the (feature, threshold) pair with the highest information gain."""
        best_gain = -1
        split_idx, split_threshold = None, None
        for feature_index in feature_indices:
            X_column = X[:, feature_index]
            for threshold in np.unique(X_column):
                gain = self._information_gain(y, X_column, threshold)
                if gain > best_gain:
                    best_gain = gain
                    split_idx = feature_index
                    split_threshold = threshold
        return split_idx, split_threshold

    def _split(self, X_column, threshold):
        """Indices of samples going to the left (<= threshold) and right (> threshold) children."""
        left = np.argwhere(X_column <= threshold).flatten()
        right = np.argwhere(X_column > threshold).flatten()
        return left, right

    def _information_gain(self, y, X_column, split_threshold):
        """Entropy of the parent minus the weighted entropy of the two children."""
        parent_entropy = self._entropy(y)
        left_indices, right_indices = self._split(X_column, split_threshold)
        if len(left_indices) == 0 or len(right_indices) == 0:
            return 0
        n = len(y)
        n_l, n_r = len(left_indices), len(right_indices)
        e_l, e_r = self._entropy(y[left_indices]), self._entropy(y[right_indices])
        child_entropy = (n_l / n) * e_l + (n_r / n) * e_r
        return parent_entropy - child_entropy

    def _entropy(self, y):
        """Shannon entropy of an integer label array."""
        hist = np.bincount(y)
        ps = hist / np.sum(hist)
        return -np.sum([p * np.log2(p) for p in ps if p > 0])

    def _majority_vote(self, y):
        """Most frequent label, used as the prediction of a leaf."""
        return np.bincount(y).argmax()

    def _predict(self, inputs):
        """Walk a single sample from the root to a leaf and return its value."""
        node = self.root
        while node.left:  # only leaf nodes have no children
            if inputs[node.feature] <= node.threshold:
                node = node.left
            else:
                node = node.right
        return node.value
```
Here, fit trains the decision tree and predict makes predictions. _build_tree constructs the tree recursively, _best_criteria finds the best split feature and threshold, _split partitions the dataset by a feature and threshold, _information_gain computes the information gain of a split, _entropy computes the entropy of a label set, _majority_vote determines the value of a leaf node, and _predict routes a single input down the tree to a leaf.
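To make the entropy and information-gain computation concrete, here is a small standalone worked example on a toy label array (the numbers are illustrative and not taken from the code above):
```python
import numpy as np

y = np.array([0, 0, 0, 1, 1, 1, 1, 1])  # 3 samples of class 0, 5 of class 1

# Shannon entropy: -sum(p * log2(p)) over the class probabilities
ps = np.bincount(y) / len(y)             # [0.375, 0.625]
entropy = -np.sum([p * np.log2(p) for p in ps if p > 0])
print(round(entropy, 3))                 # 0.954 bits

# A split that isolates all of class 0 produces two pure children with zero
# entropy, so its information gain equals the full parent entropy (0.954).
```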
Finally, we can use the code above to train the model and make predictions.
```python
import numpy as np
from sklearn.datasets import load_iris

iris = load_iris()
X, y = iris.data, iris.target

model = DecisionTree(max_depth=4)
model.fit(X, y)

y_pred = model.predict(X)
accuracy = np.mean(y_pred == y)  # accuracy measured on the training data itself
print("Accuracy:", accuracy)
```
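Because the snippet above evaluates on the same data it was trained on, the reported accuracy is optimistic. A minimal sketch of a fairer evaluation, using scikit-learn's train_test_split to hold out a test set:
```python
from sklearn.model_selection import train_test_split

# Hold out 30% of the data for testing; random_state fixes the shuffle.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

model = DecisionTree(max_depth=4)
model.fit(X_train, y_train)

test_accuracy = np.mean(model.predict(X_test) == y_test)
print("Test accuracy:", test_accuracy)
```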