神经网络深度学习CNN MATLAB代码
时间: 2023-07-05 15:29:22 浏览: 105
好的,以下是一个简单的CNN的MATLAB代码示例,用于对MNIST手写数字数据集进行分类:
```
% 加载MNIST数据集
load mnist.mat
% 创建卷积神经网络模型
layers = [
imageInputLayer([28 28 1])
convolution2dLayer(5, 20)
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2, 'Stride', 2)
fullyConnectedLayer(10)
softmaxLayer
classificationLayer];
% 设置训练选项
options = trainingOptions('adam', ...
'InitialLearnRate', 0.001, ...
'MaxEpochs', 10, ...
'MiniBatchSize', 128, ...
'Shuffle', 'every-epoch', ...
'Verbose', false, ...
'Plots', 'training-progress');
% 训练模型
net = trainNetwork(trainX, trainY, layers, options);
% 预测测试集数据
YTest = classify(net, testX);
% 计算准确率
accuracy = sum(YTest == testY) / numel(testY);
fprintf('准确率: %.2f%%\n', accuracy*100);
```
这段代码中,首先加载了MNIST数据集,并创建了一个包含卷积层、批归一化层、ReLU激活层、池化层、全连接层、softmax层和分类层的CNN模型。然后,使用trainNetwork函数训练模型,并使用classify函数预测测试集数据。最后,计算准确率并输出结果。
阅读全文