How to write decision tree code without using the sklearn library
Sure! Here is a simple example implementation of a decision tree:
```python
import numpy as np


class DecisionTree:
    def __init__(self, max_depth=None):
        # max_depth=None lets the tree grow until every leaf is pure
        self.max_depth = max_depth

    def fit(self, X, y):
        # X: array of shape (n_samples, n_features); y: integer labels 0..n_classes-1
        X = np.asarray(X)
        y = np.asarray(y)
        self.n_features = X.shape[1]
        self.n_classes = len(set(y))
        self.tree = self._grow_tree(X, y)

    def _grow_tree(self, X, y, depth=0):
        # The majority class of the samples reaching this node is its prediction
        num_samples_per_class = [np.sum(y == i) for i in range(self.n_classes)]
        predicted_class = int(np.argmax(num_samples_per_class))
        node = {'predicted_class': predicted_class}
        if self.max_depth is None or depth < self.max_depth:
            best_gain = 0.0
            best_feature = None
            best_threshold = None
            # Try every (feature, threshold) candidate and keep the split
            # with the largest information gain
            for feature in range(self.n_features):
                for threshold in set(X[:, feature]):
                    gain = self._information_gain(X, y, feature, threshold)
                    if gain > best_gain:
                        best_gain = gain
                        best_feature = feature
                        best_threshold = threshold
            if best_gain > 0.0:
                left_indices = X[:, best_feature] <= best_threshold
                right_indices = ~left_indices
                node['feature'] = best_feature
                node['threshold'] = best_threshold
                # Recurse on each side of the split with only its own samples
                node['left'] = self._grow_tree(X[left_indices], y[left_indices], depth + 1)
                node['right'] = self._grow_tree(X[right_indices], y[right_indices], depth + 1)
        return node

    def _information_gain(self, X, y, feature, threshold):
        parent_entropy = self._entropy(y)
        left_indices = X[:, feature] <= threshold
        right_indices = ~left_indices
        if left_indices.sum() == 0 or right_indices.sum() == 0:
            return 0.0  # the split leaves one side empty, so it is useless
        n = len(y)
        left_weight = left_indices.sum() / n
        right_weight = right_indices.sum() / n
        left_entropy = self._entropy(y[left_indices])
        right_entropy = self._entropy(y[right_indices])
        return parent_entropy - (left_weight * left_entropy + right_weight * right_entropy)

    def _entropy(self, y):
        class_counts = np.bincount(y)
        probabilities = class_counts / len(y)
        entropy = 0.0
        for prob in probabilities:
            if prob > 0:
                entropy -= prob * np.log2(prob)
        return entropy

    def predict(self, X):
        X = np.asarray(X)
        return np.array([self._traverse_tree(x, self.tree) for x in X])

    def _traverse_tree(self, x, node):
        # Internal nodes carry a 'feature' key; leaves only store 'predicted_class'
        if 'feature' not in node:
            return node['predicted_class']
        if x[node['feature']] <= node['threshold']:
            return self._traverse_tree(x, node['left'])
        return self._traverse_tree(x, node['right'])
```
This is a basic decision tree implementation that does not rely on sklearn at all. You train the model with the `fit` method and then predict new samples with the `predict` method. Note that it is a simplified version and does not offer all of the features and optimizations of sklearn's decision trees.
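As a quick sanity check, here is a minimal usage sketch on a made-up toy dataset (the feature values, labels, and the `max_depth=3` setting below are purely illustrative, not part of the original answer):

```python
import numpy as np

# Hypothetical toy data: two features, two classes separable on the first feature
X = np.array([[2.0, 3.0],
              [1.0, 1.0],
              [3.0, 4.0],
              [5.0, 2.0],
              [6.0, 5.0],
              [7.0, 6.0]])
y = np.array([0, 0, 0, 1, 1, 1])

tree = DecisionTree(max_depth=3)
tree.fit(X, y)
print(tree.predict(np.array([[1.5, 2.0], [6.5, 5.5]])))  # expected output: [0 1]
```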