Experiment with the watermelon dataset: train a decision tree and test it, using rows 1-7 and 10-17 of the dataset as the training set and the remaining rows as the test set. 1) Implement a decision tree learning algorithm that uses information gain as the purity measure for sample sets, draw the resulting decision tree, and report the corresponding test results. Provide MATLAB code.
Here is MATLAB code implementing the decision tree learning algorithm:
```matlab
% Load the dataset (assumed to be a numeric matrix named watermelon,
% with integer-coded attributes and the class label in the last column)
load watermelon.mat;
% Rows 1-7 and 10-17 form the training set; the remaining rows (8 and 9) the test set
train_data = [watermelon(1:7, :); watermelon(10:17, :)];
test_data = watermelon(8:9, :);
% Learn the decision tree on the training set using information gain
tree = create_tree(train_data, 1:size(watermelon, 2)-1);
% Print the decision tree
view_tree(tree);
% Evaluate the tree on the test set
accuracy = test_tree(tree, test_data);
fprintf('Test accuracy: %.2f%%\n', accuracy * 100);
```
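The script above assumes `watermelon.mat` holds a 17-row numeric matrix named `watermelon`, with integer-coded categorical attributes in the leading columns and the class label in the last column. If you are starting from the raw watermelon dataset, a minimal sketch of such an encoding might look like this (the column meanings and codes here are illustrative assumptions, not part of the original data file):
```matlab
% Hypothetical integer encoding: one row per melon,
% columns = [color, root, knock, texture, navel, touch, label],
% e.g. color: 1 = green, 2 = dark, 3 = pale; label: 1 = good, 0 = bad.
watermelon = [
    1 1 1 1 1 1 1;   % sample 1
    2 1 2 1 1 1 1;   % sample 2
    3 2 1 2 2 2 0];  % sample 3 (continue likewise for all 17 rows)
save('watermelon.mat', 'watermelon');
```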
Here, `create_tree`, `view_tree`, and `test_tree` are the functions that build, display, and test the decision tree, respectively. Their implementations are as follows:
```matlab
% Build a decision tree recursively
function tree = create_tree(data, features)
% data: training samples; features: column indices of the attributes still available
classes = unique(data(:, end)); % class labels present in data
if length(classes) == 1 % all samples share one class: make a leaf
    tree = struct('is_leaf', true, 'class', classes(1));
    return;
end
if isempty(features) % no attributes left: majority-vote leaf
    tree = struct('is_leaf', true, 'class', mode(data(:, end)));
    return;
end
info_gain = zeros(1, length(features)); % information gain of each candidate attribute
for i = 1:length(features)
    info_gain(i) = calc_info_gain(data, features(i));
end
[~, best_idx] = max(info_gain); % pick the attribute with the largest information gain
best_feature = features(best_idx); % map the index back to the actual column number
tree = struct('is_leaf', false, 'feature', best_feature);
values = unique(data(:, best_feature)); % values the chosen attribute takes here
tree.values = values; % record the value-to-branch mapping for prediction
for i = 1:length(values) % one subtree per attribute value
    sub_data = data(data(:, best_feature) == values(i), :);
    sub_features = setdiff(features, best_feature);
    tree.sub_trees{i} = create_tree(sub_data, sub_features);
end
end
% Information entropy of a sample set (class label in the last column)
function entropy = calc_entropy(data)
classes = unique(data(:, end));
p = zeros(1, length(classes));
for i = 1:length(classes)
    p(i) = sum(data(:, end) == classes(i)) / size(data, 1); % class proportion
end
entropy = -dot(p, log2(p)); % every p(i) > 0 here, so log2 is well defined
end
% Information gain of splitting data on the given attribute column
function info_gain = calc_info_gain(data, feature)
values = unique(data(:, feature)); % values the attribute takes
sub_entropies = zeros(1, length(values));
for i = 1:length(values)
    sub_data = data(data(:, feature) == values(i), :);
    sub_entropies(i) = calc_entropy(sub_data); % entropy of each branch
end
% Weight each branch entropy by the fraction of samples it receives
weights = arrayfun(@(v) sum(data(:, feature) == values(v)) / size(data, 1), 1:length(values));
info_gain = calc_entropy(data) - dot(weights, sub_entropies);
end
% Print the decision tree as indented text
function view_tree(tree)
if tree.is_leaf % leaf node: print its class label
    fprintf(' %d ', tree.class);
else % internal node: print the split attribute, then recurse into each branch
    fprintf(' x%d ', tree.feature);
    for i = 1:length(tree.sub_trees)
        fprintf('\n\t');
        view_tree(tree.sub_trees{i});
    end
end
end
% Accuracy of the tree on a labelled data set
function accuracy = test_tree(tree, data)
num_correct = 0;
for i = 1:size(data, 1)
    x = data(i, 1:end-1); % attribute values
    y = data(i, end); % true label
    y_pred = predict_tree(tree, x); % predicted label
    if y_pred == y
        num_correct = num_correct + 1;
    end
end
accuracy = num_correct / size(data, 1);
end
% Predict the label of a single sample
function y_pred = predict_tree(tree, x)
if tree.is_leaf % leaf node: return the stored class
    y_pred = tree.class;
    return;
end
% Follow the branch whose recorded attribute value matches this sample
idx = find(tree.values == x(tree.feature), 1);
if isempty(idx) % value never seen at this node during training: fall back to branch 1
    idx = 1;
end
y_pred = predict_tree(tree.sub_trees{idx}, x);
end
```
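For reference, `calc_entropy` and `calc_info_gain` implement the standard definitions. For a sample set $D$ in which class $k$ has proportion $p_k$, and an attribute $a$ whose value $v$ selects the subset $D^v$:

$$\mathrm{Ent}(D) = -\sum_{k} p_k \log_2 p_k, \qquad \mathrm{Gain}(D, a) = \mathrm{Ent}(D) - \sum_{v} \frac{|D^v|}{|D|}\,\mathrm{Ent}(D^v).$$

The `weights` vector in `calc_info_gain` holds the $|D^v|/|D|$ factors, and `sub_entropies` holds the $\mathrm{Ent}(D^v)$ terms.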
Running the script prints the decision tree and the test accuracy. The tree is a nested struct: a leaf node carries `is_leaf = true` and its `class`; an internal node carries `is_leaf = false`, `feature` (the column index of the split attribute), `values` (the attribute values observed at that node), and `sub_trees` (a cell array with one subtree per value). The accuracy is computed on the held-out test set (rows 8 and 9).
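As a quick sanity check, you can also classify one held-out sample by hand; the snippet below (using the variables defined in the main script above) prints the prediction for row 8 of the dataset:
```matlab
x = test_data(1, 1:end-1);        % attribute values of the first test sample (row 8)
y_pred = predict_tree(tree, x);   % label predicted by the learned tree
fprintf('Sample 8: predicted %d, actual %d\n', y_pred, test_data(1, end));
```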