给出用Python实现决策树算法的代码,不调用sklearn库,对鸢尾花数据集进行分类,并实现结果可视化
时间: 2024-05-08 11:19:00 浏览: 13
以下是不调用sklearn库、用Python实现决策树算法的代码,它对鸢尾花数据集进行分类,并将分类结果可视化。
```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Load the dataset: every column except the last is used as a feature,
# the last column is used as the class label.
data = pd.read_csv('iris.csv')
_label_col = data.shape[1] - 1
X = data.iloc[:, :_label_col].values
y = data.iloc[:, _label_col].values
# Decision tree node
class DecisionNode:
    """One node of a binary decision tree.

    An internal node carries a split rule (``feature_idx``, ``threshold``)
    and its two children; a leaf carries the predicted class in ``result``.
    All fields default to ``None``.
    """

    def __init__(self, feature_idx=None, threshold=None, left=None, right=None, result=None):
        # Split rule: index of the feature tested and the value it is compared to.
        self.feature_idx, self.threshold = feature_idx, threshold
        # Child subtrees.
        self.left, self.right = left, right
        # Class label stored on a leaf; None on internal nodes.
        self.result = result
# Gini impurity of a label array
def calculate_gini(y):
    """Return the Gini impurity of the labels in *y*.

    Starting from 1, the squared share of every distinct label is
    subtracted in turn. An empty *y* has no distinct labels, so nothing is
    subtracted and 1 is returned.
    """
    total = len(y)
    impurity = 1
    for label in np.unique(y):
        share = np.count_nonzero(y == label) / total
        impurity -= share ** 2
    return impurity
# Partition a dataset on one feature and a threshold
def split_data(X, y, feature_idx, threshold):
    """Split (*X*, *y*) into a left and a right part on one feature.

    Rows whose value in column ``feature_idx`` is ``<= threshold`` go to the
    left part, rows whose value is ``> threshold`` go to the right part.

    Returns the tuple ``(left_X, left_y, right_X, right_y)``.
    """
    column = X[:, feature_idx]
    # Two independent comparisons: a row that satisfies neither ends up in
    # neither part.
    goes_left = column <= threshold
    goes_right = column > threshold
    return X[goes_left], y[goes_left], X[goes_right], y[goes_right]
# Choose the best split feature and threshold by weighted Gini impurity
def select_split(X, y):
    """Find the (feature, threshold) pair with the lowest weighted Gini impurity.

    Every distinct value of every feature column is considered as a
    threshold; a candidate sends rows with ``value <= threshold`` left and
    the rest right. Candidates that leave either side empty are ignored,
    because they do not separate anything and would make a recursive tree
    builder loop on the same data forever.

    Returns:
        ``(feature_idx, threshold)`` of the best candidate (the first one
        found on ties), or ``(None, None)`` when no candidate splits the
        samples into two non-empty groups.
    """
    best_gini = float('inf')
    best_feature_idx = None
    best_threshold = None
    n_samples = len(y)
    for feature_idx in range(X.shape[1]):
        feature_values = np.unique(X[:, feature_idx])
        # np.unique returns sorted values, so the last one is the maximum:
        # using it as a threshold would put every sample on the left.
        for threshold in feature_values[:-1]:
            _, left_y, _, right_y = split_data(X, y, feature_idx, threshold)
            # Still guard explicitly: with NaN values in the column the
            # slice above may not be enough to guarantee two non-empty sides.
            if len(left_y) == 0 or len(right_y) == 0:
                continue
            gini = (len(left_y) / n_samples * calculate_gini(left_y)
                    + len(right_y) / n_samples * calculate_gini(right_y))
            if gini < best_gini:
                best_gini = gini
                best_feature_idx = feature_idx
                best_threshold = threshold
    return best_feature_idx, best_threshold
# Build the decision tree recursively
def build_tree(X, y):
    """Recursively grow a decision tree for features *X* and labels *y*.

    A node becomes a leaf when all its labels are equal, or when no split
    can divide its samples into two non-empty groups (for example samples
    with identical features but different labels); in the latter case the
    leaf predicts the most frequent label. Without that second stop
    condition the recursion would never terminate on such data.

    Returns:
        The root ``DecisionNode`` of the (sub)tree.

    Raises:
        ValueError: if *y* is empty.
    """
    classes, counts = np.unique(y, return_counts=True)
    if len(classes) == 0:
        raise ValueError("build_tree requires at least one sample")
    if len(classes) == 1:  # pure node: return a leaf
        return DecisionNode(result=y[0])

    majority = classes[np.argmax(counts)]
    feature_idx, threshold = select_split(X, y)
    if feature_idx is None:
        # No candidate split was found at all.
        return DecisionNode(result=majority)

    left_X, left_y, right_X, right_y = split_data(X, y, feature_idx, threshold)
    if len(left_y) == 0 or len(right_y) == 0:
        # The chosen split does not separate the samples; recursing on it
        # would repeat this exact call forever.
        return DecisionNode(result=majority)

    left_tree = build_tree(left_X, left_y)
    right_tree = build_tree(right_X, right_y)
    return DecisionNode(feature_idx=feature_idx, threshold=threshold, left=left_tree, right=right_tree)
# Predict the class of a single sample
def predict_sample(tree, sample):
    """Walk *tree* from the root down to a leaf and return that leaf's class.

    At each internal node the sample's value for the node's feature is
    compared to the node's threshold: ``<=`` descends left, otherwise right.
    A node counts as a leaf as soon as its ``result`` is not ``None``.
    """
    node = tree
    while node.result is None:
        node = node.left if sample[node.feature_idx] <= node.threshold else node.right
    return node.result
# Predict the class of every sample in a dataset
def predict(tree, X):
    """Return the predicted class of every row of *X* as a 1-D array.

    The result array is built from the predictions themselves, so its dtype
    follows the class labels stored in the tree. A pre-allocated float
    buffer would raise ``ValueError`` as soon as the labels are strings
    (e.g. species names).

    For an empty *X* an empty array is returned.
    """
    return np.array([predict_sample(tree, sample) for sample in X])
# Visualise the classification result
def plot(X, y_true, y_pred):
    """Scatter the first two feature columns, coloured by true and predicted class.

    True labels are drawn as translucent dots and predictions as ``x``
    markers on top of them. Both layers share one label-to-colour mapping,
    so a cross whose colour differs from its dot marks a misclassified
    sample. Labels are first encoded as integer codes, which also makes
    non-numeric labels (e.g. species names) plottable.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    n_true = len(y_true)

    # Encode both label arrays jointly so that a given class gets the same
    # integer code (and therefore the same colour) in both layers.
    classes, codes = np.unique(np.concatenate([y_true, y_pred]), return_inverse=True)
    codes = codes.ravel()
    true_codes, pred_codes = codes[:n_true], codes[n_true:]

    # Fixed colour limits: otherwise each scatter call normalises on its own
    # data and a layer that misses a class would shift all its colours.
    colour_kwargs = dict(cmap='tab10', vmin=0, vmax=max(len(classes) - 1, 1))

    plt.scatter(X[:, 0], X[:, 1], c=true_codes, alpha=0.4, **colour_kwargs)
    plt.scatter(X[:, 0], X[:, 1], c=pred_codes, marker='x', **colour_kwargs)
    plt.xlabel('Sepal length')
    plt.ylabel('Sepal width')
    plt.title('Iris data classification')
    plt.show()
# Shuffle the samples with a fixed seed, then use the first 70% for
# training and the remaining 30% for testing.
np.random.seed(2021)
shuffle_idx = np.random.permutation(X.shape[0])
X, y = X[shuffle_idx], y[shuffle_idx]
split_idx = int(0.7 * X.shape[0])
X_train, X_test = X[:split_idx], X[split_idx:]
y_train, y_test = y[:split_idx], y[split_idx:]

# Fit the decision tree on the training set.
tree = build_tree(X_train, y_train)

# Predict on both sets, then show one figure per set.
y_train_pred, y_test_pred = predict(tree, X_train), predict(tree, X_test)
plot(X_train, y_train, y_train_pred)
plot(X_test, y_test, y_test_pred)
```
运行该代码前,请确保当前目录下存在 iris.csv 文件(最后一列为类别标签,其余列为特征)。运行后,代码会对鸢尾花数据集进行分类,并分别可视化训练集和测试集上的分类结果。