% 导入预训练的model opts.modelPath = fullfile('..','models','imagenet-vgg-verydeep-16.mat'); [opts, varargin] = vl_argparse(opts, varargin) ; opts.numFetchThreads = 12 ; opts.lite = false ; opts.imdbPath = fullfile(opts.expDir, 'imdb.mat'); opts.train = struct() ; opts.train.gpus = []; opts.train.batchSize = 8 ; opts.train.numSubBatches = 4 ; opts.train.learningRate = 1e-4 * [ones(1,10), 0.1*ones(1,5)]; opts = vl_argparse(opts, varargin) ; if ~isfield(opts.train, 'gpus'), opts.train.gpus = []; end; % ------------------------------------------------------------------------- % Prepare model % ------------------------------------------------------------------------- net = load(opts.modelPath); % 修改一下这个model net = prepareDINet(net,opts); % ------------------------------------------------------------------------- % Prepare data % ------------------------------------------------------------------------- % 准备数据格式 if exist(opts.imdbPath,'file') imdb = load(opts.imdbPath) ; else imdb = cnn_image_setup_data('dataDir', opts.dataDir, 'lite', opts.lite) ; mkdir(opts.expDir) ; save(opts.imdbPath, '-struct', 'imdb') ; end imdb.images.set = imdb.images.sets; % Set the class names in the network net.meta.classes.name = imdb.classes.name ; net.meta.classes.description = imdb.classes.name ; % % 求训练集的均值 imageStatsPath = fullfile(opts.expDir, 'imageStats.mat') ; if exist(imageStatsPath) load(imageStatsPath, 'averageImage') ; else averageImage = getImageStats(opts, net.meta, imdb) ; save(imageStatsPath, 'averageImage') ; end % % 用新的均值改变均值 net.meta.normalization.averageImage = averageImage; % ------------------------------------------------------------------------- % Learn % ------------------------------------------------------------------------- % 索引训练集==1 和测试集==3 opts.train.train = find(imdb.images.set==1) ; opts.train.val = find(imdb.images.set==3) ; % 训练 [net, info] = cnn_train_dag(net, imdb, getBatchFn(opts, net.meta), ... 'expDir', opts.expDir, ... opts.train) ;
时间: 2023-06-19 08:10:01 浏览: 216
inception_resnet_v2 预训练模型
这段代码是一个深度学习的训练过程,主要包括准备模型、准备数据、训练模型三个步骤。其中,模型使用的是预训练的imagenet-vgg-verydeep-16模型,数据使用的是经过预处理后的imdb数据集,训练过程使用的是cnn_train_dag函数。具体流程如下:
1. 导入预训练的模型,将其存储在net变量中;
2. 使用prepareDINet函数对模型进行修改,以适应当前任务;
3. 准备数据,如果之前已经生成了imdb文件,则直接载入,否则使用cnn_image_setup_data函数生成并保存imdb文件;
4. 设置模型的分类名称和描述;
5. 计算训练集的均值,如果已经计算过,则从文件中读取,否则使用getImageStats函数计算并保存;
6. 使用新的均值修改模型的normalization参数;
7. 确定训练集和测试集的索引;
8. 调用cnn_train_dag函数进行训练,并将训练结果存储在net和info两个变量中。
阅读全文