基于python代码,针对鸢尾花数据集,实现C4.5,并对鸢尾花数据进行分类
时间: 2023-07-24 09:13:54 浏览: 115
python利用c4.5决策树对鸢尾花卉数据集进行分类(iris)(代码全)
5星 · 资源好评率100%
好的,以下是基于Python实现C4.5算法对鸢尾花数据集进行分类的示例代码:
```python
import pandas as pd
import numpy as np
from math import log2
class TreeNode:
def __init__(self, feature=None, threshold=None, label=None, left=None, right=None):
self.feature = feature
self.threshold = threshold
self.label = label
self.left = left
self.right = right
class C45:
def __init__(self, epsilon=1e-4):
self.epsilon = epsilon
self.tree = None
def fit(self, X, y):
self.tree = self._build_tree(X, y)
def predict(self, X):
y_pred = []
for i in range(len(X)):
y_pred.append(self._predict_one(X[i]))
return np.array(y_pred)
def _predict_one(self, x):
node = self.tree
while node.label is None:
if x[node.feature] <= node.threshold:
node = node.left
else:
node = node.right
return node.label
def _build_tree(self, X, y):
n_samples, n_features = X.shape
if len(set(y)) == 1:
return TreeNode(label=y[0])
if n_samples == 0:
return TreeNode(label=None)
if n_features == 0:
return TreeNode(label=self._majority_vote(y))
entropy = self._entropy(y)
max_gain_ratio = -1
best_feature = None
best_threshold = None
for i in range(n_features):
values = sorted(set(X[:, i]))
if len(values) == 1:
continue
for j in range(len(values) - 1):
threshold = (values[j] + values[j+1]) / 2
y_left = y[X[:, i] <= threshold]
y_right = y[X[:, i] > threshold]
if len(y_left) == 0 or len(y_right) == 0:
continue
gain = entropy - self._information_gain(y_left, y_right)
split_info = self._split_info(y_left, y_right)
if split_info == 0:
continue
gain_ratio = gain / split_info
if gain_ratio > max_gain_ratio:
max_gain_ratio = gain_ratio
best_feature = i
best_threshold = threshold
if best_feature is None:
return TreeNode(label=self._majority_vote(y))
X_left = X[X[:, best_feature] <= best_threshold]
y_left = y[X[:, best_feature] <= best_threshold]
X_right = X[X[:, best_feature] > best_threshold]
y_right = y[X[:, best_feature] > best_threshold]
if max_gain_ratio < self.epsilon:
return TreeNode(label=self._majority_vote(y))
left = self._build_tree(X_left, y_left)
right = self._build_tree(X_right, y_right)
return TreeNode(feature=best_feature, threshold=best_threshold, left=left, right=right)
def _entropy(self, y):
_, counts = np.unique(y, return_counts=True)
p = counts / len(y)
return -np.sum(p * np.log2(p))
def _information_gain(self, y_left, y_right):
p_left = len(y_left) / (len(y_left) + len(y_right))
p_right = len(y_right) / (len(y_left) + len(y_right))
return p_left * self._entropy(y_left) + p_right * self._entropy(y_right)
def _split_info(self, y_left, y_right):
p_left = len(y_left) / (len(y_left) + len(y_right))
p_right = len(y_right) / (len(y_left) + len(y_right))
if p_left == 0 or p_right == 0:
return 0
return - p_left * log2(p_left) - p_right * log2(p_right)
def _majority_vote(self, y):
return np.argmax(np.bincount(y))
```
在这个示例代码中,我们定义了一个`C45`类,其中包含了`fit`和`predict`方法用于训练和预测。在`fit`方法中,我们调用`_build_tree`方法来构建决策树。在`predict`方法中,我们遍历决策树来预测每个样本的类别。
接下来,我们可以使用鸢尾花数据集来测试这个算法:
```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=42)
clf = C45()
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
acc = accuracy_score(y_test, y_pred)
print("Accuracy:", acc)
```
在这个测试代码中,我们首先使用`train_test_split`方法将数据集划分为训练集和测试集。然后,我们使用训练集来训练我们实现的C4.5算法,并使用测试集来测试算法的准确率。最后,我们打印出准确率。
这就是使用Python实现C4.5算法对鸢尾花数据集进行分类的示例代码。
阅读全文