基于MATLAB的MNIST手写体数字识别
时间: 2025-01-03 20:37:00 浏览: 10
### 使用MATLAB实现MNIST手写体数字识别
#### 加载并预处理MNIST数据集
在开始之前,需要先加载并预处理MNIST数据集。MNIST是一个手写数字图像数据库,包含60000张训练图像和10000张测试图像,每张图像为28×28像素的灰度图[^3]。
```matlab
% Download and decompress the four MNIST archives (skipped when the
% decompressed file is already present in the current folder).
% NOTE(review): the original host may reject anonymous downloads; if websave
% fails, point `url` at a mirror that serves the same file names.
url = 'http://yann.lecun.com/exdb/mnist/';
% Canonical archive names: images use the idx3 format, labels use idx1.
files = {'train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz', ...
         't10k-images-idx3-ubyte.gz',  't10k-labels-idx1-ubyte.gz'};
for i = 1:numel(files)
    % Strip the trailing ".gz" to get the decompressed file name.
    unpackedName = files{i}(1:end-3);
    if ~exist(unpackedName, 'file')
        websave(files{i}, [url files{i}]);
        gunzip(files{i});
    end
end
% 定义读取函数来解析二进制文件
function [images, labels] = load_mnist_data(image_file, label_file)
% LOAD_MNIST_DATA Read an MNIST IDX image file and its IDX label file.
%
%   [images, labels] = load_mnist_data(image_file, label_file)
%
%   Inputs:
%     image_file - path to a decompressed idx3-ubyte image file
%     label_file - path to a decompressed idx1-ubyte label file
%
%   Outputs:
%     images - rows-by-cols-by-N double array, pixel values as stored (0..255)
%     labels - 1-by-N double row vector of label values
%
%   Errors are raised when a file cannot be opened, when a magic number does
%   not match the IDX format, or when a file is shorter than its header says.

LABEL_MAGIC = 2049;   % 0x00000801: unsigned byte data, 1 dimension
IMAGE_MAGIC = 2051;   % 0x00000803: unsigned byte data, 3 dimensions

% ---- labels ----
lblFid = fopen(label_file, 'r', 'ieee-be');   % IDX headers are big-endian
if lblFid < 0
    error('load_mnist_data:openFailed', 'Cannot open label file "%s".', label_file);
end
% Close only this handle, even on error (fclose('all') would also close
% unrelated files the caller has open).
lblCleanup = onCleanup(@() fclose(lblFid));

magic_num = fread(lblFid, 1, 'int32');
if isempty(magic_num) || magic_num ~= LABEL_MAGIC
    error('load_mnist_data:badMagic', '"%s" is not an IDX label file.', label_file);
end
size_labels = fread(lblFid, 1, 'int32');
labels = fread(lblFid, [size_labels, 1], 'unsigned char')';
if numel(labels) ~= size_labels
    error('load_mnist_data:truncated', 'Label file "%s" is truncated.', label_file);
end

% ---- images ----
imgFid = fopen(image_file, 'r', 'ieee-be');
if imgFid < 0
    error('load_mnist_data:openFailed', 'Cannot open image file "%s".', image_file);
end
imgCleanup = onCleanup(@() fclose(imgFid));

magic_num = fread(imgFid, 1, 'int32');
if isempty(magic_num) || magic_num ~= IMAGE_MAGIC
    error('load_mnist_data:badMagic', '"%s" is not an IDX image file.', image_file);
end
size_images = fread(imgFid, 1, 'int32');
rows = fread(imgFid, 1, 'int32');
cols = fread(imgFid, 1, 'int32');
pixels = fread(imgFid, size_images*rows*cols, 'unsigned char');
if numel(pixels) ~= size_images*rows*cols
    error('load_mnist_data:truncated', 'Image file "%s" is truncated.', image_file);
end

% Pixels are stored row by row, while reshape fills column by column:
% reshape to cols-by-rows first, then swap the first two dimensions so each
% slice is an upright rows-by-cols image.
images = permute(reshape(pixels, cols, rows, size_images), [2 1 3]);
end
% Load both splits; the names are the decompressed MNIST archive names.
[trainImages, trainLabels] = load_mnist_data('train-images-idx3-ubyte', ...
                                             'train-labels-idx1-ubyte');
[testImages, testLabels]   = load_mnist_data('t10k-images-idx3-ubyte', ...
                                             't10k-labels-idx1-ubyte');
disp(['Training set dimensions:', num2str(size(trainImages))])
disp(['Test set dimensions:' ,num2str(size(testImages))])
% Show a few sample images. montage expects a multiframe grayscale stack as
% M-by-N-by-1-by-K; converting to uint8 keeps the 0..255 pixels from being
% rendered as saturated white (double images are displayed on a 0..1 scale).
numSamples = min(8, size(trainImages, 3));
sampleStack = reshape(uint8(trainImages(:,:,1:numSamples)), ...
                      size(trainImages,1), size(trainImages,2), 1, numSamples);
figure;
montage(sampleStack);
title('Sample Training Images from MNIST Dataset')
```
#### 构建神经网络模型
接下来构建一个多层感知器(MLP),用于对手写数字进行分类。
```matlab
% Single-hidden-layer pattern-recognition network.
hiddenLayerSize = 100;
net = patternnet(hiddenLayerSize);
% Training parameters
net.trainParam.showWindow = false;
net.trainParam.epochs = 50;
net.trainParam.goal = 1e-6;
% The network takes one sample per COLUMN: flatten every image into a
% column (features-by-samples) and scale pixels to [0, 1].
trainData = double(reshape(trainImages,[],size(trainImages,3)))/255;
testData = double(reshape(testImages,[],size(testImages,3)))/255;
% Targets must be one-hot columns. Labels are 0..9, class indices are 1..10.
trainTargets = full(ind2vec(trainLabels(:)' + 1, 10));
% Train the network
[trainedNet,tr] = train(net,trainData,trainTargets);
% Evaluate: vec2ind yields class indices 1..10, so shift back to 0..9
% before comparing with the stored labels.
predictedLabels = vec2ind(trainedNet(testData)) - 1;
accuracy = sum(predictedLabels == testLabels(:)')/numel(testLabels)*100;
fprintf('Accuracy on the testing dataset is %.2f%%.\n', accuracy);
```
#### 创建图形界面(GUI)供用户绘制自己的数字
为了使应用程序更加友好,可以创建一个简单的GUI,让用户能够画出自己的手写数字,并通过已训练好的模型预测其类别[^1]。
```matlab
function create_gui()
% CREATE_GUI Build the drawing window: a canvas axes plus Clear/Predict buttons.
%
%   The figure handle is passed to every drawing callback as its third
%   argument; the drawing state (struct with field drawingMode) is stored on
%   the figure with guidata.
f = uifigure('Position',[300 300 400 400],...
    'Name','Handwritten Digit Recognition GUI');
% Canvas: kept left of the buttons (which start at x = 290 px in a 400 px
% wide figure) so the two do not overlap.
ax = uiaxes(f,...
    'Units','normalized',...
    'Position',[0.05 0.15 0.62 0.7]);
% Black background so the white strokes are visible; fixed limits and no
% ticks so the canvas does not rescale or show decorations while drawing.
ax.Color = 'k';
ax.XLim = [0 1];
ax.YLim = [0 1];
ax.XTick = [];
ax.YTick = [];
% Turn off the built-in pan/zoom gestures so a drag draws instead.
disableDefaultInteractivity(ax);

uibutton(f,'Text','Clear',...
    'Position',[290 200 80 22],...
    'ButtonPushedFcn',{@clear_canvas,f});
uibutton(f,'Text','Predict',...
    'Position',[290 150 80 22],...
    'ButtonPushedFcn',{@predict_digit,f,ax});

% Initial drawing state
guidata(f,struct('drawingMode',false));

% Mouse handling. These are callback PROPERTIES, not events, so they are
% assigned directly (addlistener only accepts event names). Button-down is
% set on the axes; motion and button-up exist only on the figure.
ax.ButtonDownFcn       = @(src,event)start_drawing(src,event,f);
f.WindowButtonMotionFcn = @(src,event)motion_handler(src,event,f);
f.WindowButtonUpFcn     = @(src,event)stop_drawing(src,event,f);
end
function start_drawing(~,eventdata,hObject) %#ok<INUSL>
% START_DRAWING Mouse-down callback: switch the drawing flag on.
%   hObject is the object whose guidata holds the drawing state.
state = guidata(hObject);      % fetch the shared state struct
state.drawingMode = true;      % a stroke is now in progress
guidata(hObject, state);       % write the updated state back
end
function motion_handler(~, ~, hObject)
% MOTION_HANDLER Mouse-move callback: extend the current stroke.
%
%   hObject is the figure that owns the canvas axes and whose guidata holds
%   the drawing state. While drawingMode is true, each call draws a segment
%   from the previously recorded pointer position to the current one. The
%   previous position is kept in the state field lastPoint, which this
%   function creates and resets itself.
handles = guidata(hObject);

if ~isfield(handles,'drawingMode') || ~handles.drawingMode
    % Not drawing: forget the end of the last stroke so the next stroke
    % does not get connected to it.
    if isfield(handles,'lastPoint') && ~isempty(handles.lastPoint)
        handles.lastPoint = [];
        guidata(hObject,handles);
    end
    return;
end

% Look the axes up inside the figure; gca does not resolve to axes that
% live in a uifigure.
canvas = findobj(hObject,'Type','axes');
if isempty(canvas)
    return;
end
canvas = canvas(1);

% CurrentPoint is 2-by-3 (front/back of the view line); in a 2-D view both
% rows share x and y, so only the first row is needed.
pos = get(canvas,'CurrentPoint');
currentPoint = pos(1,1:2);

if isfield(handles,'lastPoint') && ~isempty(handles.lastPoint)
    % HitTest off lets clicks on an existing stroke reach the axes.
    line([handles.lastPoint(1); currentPoint(1)], ...
         [handles.lastPoint(2); currentPoint(2)], ...
         'Color','w','LineWidth',10,'Parent',canvas,'HitTest','off');
end

handles.lastPoint = currentPoint;
guidata(hObject,handles);
end
function stop_drawing(~, eventdata, hObject) %#ok<INUSL>
% STOP_DRAWING Mouse-up callback: switch the drawing flag off.
%   hObject is the object whose guidata holds the drawing state.
state = guidata(hObject);      % fetch the shared state struct
state.drawingMode = false;     % the stroke has ended
guidata(hObject, state);       % write the updated state back
end
function clear_canvas(~,~,hObject)
% CLEAR_CANVAS Clear-button callback: remove everything drawn on the canvas.
%   hObject is the figure containing the canvas axes (an axes handle is
%   accepted too, since findobj also matches the object it is given).
canvas = findobj(hObject,'Type','axes');
if isempty(canvas)
    return;
end
% Plain cla removes the drawn strokes only. cla(...,'reset') would also
% reset the axes properties, dropping its callbacks and appearance.
cla(canvas(1));
end
function predict_digit(~,~,hObject,axesHandle) %#ok<INUSL>
% PREDICT_DIGIT Predict-button callback: classify the drawing on the canvas.
%
%   axesHandle is the canvas axes to capture. The trained network is read
%   from the base-workspace variable trainedNet, because a callback function
%   cannot see variables created by the training script.
if ~evalin('base','exist(''trainedNet'',''var'')')
    msgbox('No trained network found: run the training step first.', ...
        'Prediction Result');
    return;
end
net = evalin('base','trainedNet');

% Capture the canvas. The captured image is already top-down, so it is
% used as is (flipping it would turn the digit upside down).
frame = getframe(axesHandle);
img = frame.cdata;
if size(img,3) == 3
    grayImg = rgb2gray(img);
else
    grayImg = img;
end
bwImg = imbinarize(grayImg);
resizedImg = imresize(bwImg,[28 28]);
% The binarized image is logical (0/1), which already matches the [0, 1]
% range of the training inputs -- no division by 255.
inputVector = double(resizedImg(:));

scores = net(inputVector);
% max returns a single index even when several classes tie.
[~, classIdx] = max(scores);
predictedLabel = classIdx - 1;      % class indices 1..10 -> digits 0..9
msgbox(sprintf('The recognized digit is:%d',predictedLabel),'Prediction Result');
end
```
阅读全文