function untitled() load('D:\mat格式的MNIST数据\test_labels.mat') load('D:\mat格式的MNIST数据\train_images.mat') load('D:\mat格式的MNIST数据\train_labels.mat') load('D:\mat格式的MNIST数据\test_images.mat') train_num = 600; test_num = 200; %训练数据,图像转向量 data_train = mat2vector(train_images(:,:,1:train_num),train_num); data_test = mat2vector(test_images(:,:,1:test_num),test_num); % 处理训练数据,防止后验概率为0 [data_train,position] = fun(data_train,train_labels1(1:train_num)'); % 处理测试数据 for rows = 1:10 data_test(:,position{1,rows})=[]; end %模型部分 Mdl = fitcnb(data_train,train_labels1(1:train_num)); %测试结果 result = predict(Mdl,data_test); result = result.'; xlabel=[0,1,2,3,4,5,6,7,8,9]; resultbar = [0,0,0,0,0,0,0,0,0,0]; testbar = [0,0,0,0,0,0,0,0,0,0]; for i = 1:test_num temp1=result(i); temp1=temp1+1; resultbar(temp1)=resultbar(temp1)+1; temp2=test_labels1(i); temp2=temp2+1; testbar(temp2)=testbar(temp2)+1; end bar(xlabel, [resultbar' testbar']); % 整体正确率 acc = 0.; for i = 1:test_num if result(i)==test_labels1(i) acc = acc+1; end end title('精确度为:',(acc/test_num)*100) end function [output,position] = fun(data,label) position = cell(1,10); %创建cell存储每类中删除的列标 for i = 0:9 temp = []; pos = []; for rows = 1:size(data,1) if label(rows)==i temp = [temp;data(rows,:)]; end end for cols = 1:size(temp,2) var_data = var(temp(:,cols)); if var_data==0 pos = [pos,cols]; end end position{i+1} = pos; data(:,pos)=[]; end output = data; end function [data_]= mat2vector(data,num) [row,col,~] = size(data); data_ = zeros(num,row*col); for page = 1:num for rows = 1:row for cols = 1:col data_(page,((rows-1)*col+cols)) = im2double(data(rows,cols,page)); end end end end 将画图部分重写,完成相同功能
时间: 2024-03-07 14:49:57 浏览: 139
Attachments_2015423.zip_Attachments
好的,以下是将画图部分重写的代码:
```matlab
% 统计预测结果和真实结果的数量
result_count = histcounts(result, 0:10);
test_count = histcounts(test_labels1, 0:10);
% 绘制柱状图
figure;
bar(0:9, [result_count', test_count']);
xlabel('数字');
ylabel('数量');
legend('预测结果', '真实结果');
% 计算并输出整体正确率
accuracy = sum(result == test_labels1) / test_num;
title(['整体正确率为:', num2str(accuracy * 100), '%']);
```
改动后,完整代码如下:
```matlab
function untitled()
load('D:\mat格式的MNIST数据\test_labels.mat')
load('D:\mat格式的MNIST数据\train_images.mat')
load('D:\mat格式的MNIST数据\train_labels.mat')
load('D:\mat格式的MNIST数据\test_images.mat')
train_num = 600;
test_num = 200;
% 训练数据,图像转向量
data_train = mat2vector(train_images(:,:,1:train_num),train_num);
data_test = mat2vector(test_images(:,:,1:test_num),test_num);
% 处理训练数据,防止后验概率为0
[data_train, position] = fun(data_train,train_labels1(1:train_num)');
% 处理测试数据
for rows = 1:10
data_test(:,position{1,rows})=[];
end
% 模型部分
Mdl = fitcnb(data_train,train_labels1(1:train_num));
% 测试结果
result = predict(Mdl,data_test);
result = result.';
% 统计预测结果和真实结果的数量
result_count = histcounts(result, 0:10);
test_count = histcounts(test_labels1, 0:10);
% 绘制柱状图
figure;
bar(0:9, [result_count', test_count']);
xlabel('数字');
ylabel('数量');
legend('预测结果', '真实结果');
% 计算并输出整体正确率
accuracy = sum(result == test_labels1) / test_num;
title(['整体正确率为:', num2str(accuracy * 100), '%']);
end
function [output,position] = fun(data,label)
position = cell(1,10); % 创建cell存储每类中删除的列标
for i = 0:9
temp = [];
pos = [];
for rows = 1:size(data,1)
if label(rows)==i
temp = [temp;data(rows,:)];
end
end
for cols = 1:size(temp,2)
var_data = var(temp(:,cols));
if var_data==0
pos = [pos,cols];
end
end
position{i+1} = pos;
data(:,pos)=[];
end
output = data;
end
function [data_] = mat2vector(data,num)
[row,col,~] = size(data);
data_ = zeros(num,row*col);
for page = 1:num
for rows = 1:row
for cols = 1:col
data_(page,((rows-1)*col+cols)) = im2double(data(rows,cols,page));
end
end
end
end
```
阅读全文