if (strcmp(test_type, 'dis_eu') || strcmp(test_type, 'dis_seu')) save(['../EXEM_results/attr_EXEM_' dataset '_' test_type '_split' num2str(opt.ind_split) '_' feature_name '_'... norm_method '_C' num2str(C) '_nu' num2str(nu) '_gamma' num2str(gamma) '_pca_d' num2str(pca_d) '.mat']... , 'C', 'nu', 'gamma', 'pca_d', 'attr2'); else save(['../EXEM_results/EXEM_' dataset '_' test_type '_split' num2str(opt.ind_split) '_' feature_name '_'... norm_method '_C' num2str(C) '_nu' num2str(nu) '_gamma' num2str(gamma) '_pca_d' num2str(pca_d) '.mat']... , 'regressors', 'C', 'nu', 'gamma', 'pca_d', 'dis_eu', 'dis_seu', 'acc_eu', 'acc_seu'); end
时间: 2023-09-10 13:08:04 浏览: 330
这段代码根据测试类型的不同,将结果保存到不同的文件中。
1. 如果测试类型是`dis_eu`或`dis_seu`,则将结果保存到一个文件中。文件名包含了数据集名称、测试类型、数据集分割索引、特征名称、归一化方法、C、nu、gamma和pca_d的值。保存的变量包括C、nu、gamma、pca_d和attr2。
2. 如果测试类型不是`dis_eu`或`dis_seu`,则将结果保存到另一个文件中。文件名的命名规则与上述相同,保存的变量包括regressors、C、nu、gamma、pca_d、dis_eu、dis_seu、acc_eu和acc_seu。
这段代码的作用是根据测试类型将结果保存到不同的文件中。根据测试类型的不同,保存的变量也不同。
相关问题
%% testing if (strcmp(task, 'test')) if(isempty(direct_test) || length(direct_test) == 1) load(['../EXEM_CV_results/EXEM_classCV_' dataset '_split' num2str(opt.ind_split) '_' feature_name '_' norm_method '.mat'],... 'val_dis_eu', 'val_dis_seu', 'val_acc_eu', 'val_acc_seu', 'opt'); if (strcmp(test_type, 'dis_eu')) [loc_C, loc_nu, loc_gamma, loc_pca_d] = find_max(-val_dis_eu, direct_test); elseif (strcmp(test_type, 'dis_seu')) [loc_C, loc_nu, loc_gamma, loc_pca_d] = find_max(-val_dis_seu, direct_test); elseif (strcmp(test_type, 'acc_eu')) [loc_C, loc_nu, loc_gamma, loc_pca_d] = find_max(val_acc_eu, direct_test); elseif (strcmp(test_type, 'acc_seu')) [loc_C, loc_nu, loc_gamma, loc_pca_d] = find_max(val_acc_seu, direct_test); else disp('Wrong test type!'); return; end C = opt.C(loc_C(1)); nu = opt.nu(loc_nu(1)); gamma = opt.gamma(loc_gamma(1)); pca_d = opt.pca_d(loc_pca_d(1)); disp([loc_C(1), loc_nu(1), loc_gamma(1), loc_pca_d(1)]); else C = direct_test(1); nu = direct_test(2); gamma = direct_test(3); pca_d = direct_test(4); end Sig_Ytr = Sig_Y(unique(Ytr), :); Sig_Yte = Sig_Y(unique(Yte), :);
这段代码是用于进行测试的部分。
首先判断是否为测试任务(`task`为'test')。如果是,则执行以下操作:
1. 如果`direct_test`为空或长度为1,则加载之前保存的交叉验证结果和参数设置。使用`load`函数从MAT文件中加载`val_dis_eu`、`val_dis_seu`、`val_acc_eu`、`val_acc_seu`和`opt`变量。这些变量保存了交叉验证过程中的评估结果和参数设置。
2. 根据`test_type`的值,调用`find_max`函数找到在测试类型下具有最大值的索引。如果`test_type`为'dis_eu',则在-val_dis_eu中找到最大值的索引;如果为'dis_seu',则在-val_dis_seu中找到最大值的索引;如果为'acc_eu',则在val_acc_eu中找到最大值的索引;如果为'acc_seu',则在val_acc_seu中找到最大值的索引。
3. 根据找到的最大值的索引,获取对应的C、nu、gamma和pca_d参数值。
4. 显示找到的最大值的索引,用于输出结果。
5. 若`direct_test`不为空且长度为4,则直接使用`direct_test`中指定的C、nu、gamma和pca_d参数值。
6. 根据训练数据集标签Ytr,在Sig_Y中选择对应类别的特征向量,得到Sig_Ytr。
7. 根据测试数据集标签Yte,在Sig_Y中选择对应类别的特征向量,得到Sig_Yte。
这段代码的目的是根据测试任务的要求,选择合适的参数设置,并获取相应的训练和测试数据集的特征向量。如果直接指定了测试参数,则使用指定的参数进行测试;否则,根据交叉验证结果选择最优参数进行测试。
if(isempty(direct_test) || length(direct_test) == 1) load(['../EXEM_CV_results/EXEM_GZSL_classCV_' dataset '_split' num2str(opt.ind_split) '_' feature_name '_' norm_method '.mat'],... 'val_acc_eu', 'val_acc_seu', 'val_HM_eu', 'val_HM_seu', 'val_bias_eu', 'val_bias_seu', 'opt'); if (strcmp(test_type, 'acc_eu')) [loc_C, loc_nu, loc_gamma, loc_pca_d] = find_max(val_acc_eu, direct_test); fixed_bias = val_bias_eu(loc_C(1), loc_nu(1), loc_gamma(1), loc_pca_d(1)); elseif (strcmp(test_type, 'acc_seu')) [loc_C, loc_nu, loc_gamma, loc_pca_d] = find_max(val_acc_seu, direct_test); fixed_bias = val_bias_seu(loc_C(1), loc_nu(1), loc_gamma(1), loc_pca_d(1)); elseif (strcmp(test_type, 'HM_eu')) [loc_C, loc_nu, loc_gamma, loc_pca_d] = find_max(val_HM_eu, direct_test); fixed_bias = val_bias_eu(loc_C(1), loc_nu(1), loc_gamma(1), loc_pca_d(1)); elseif (strcmp(test_type, 'HM_seu')) [loc_C, loc_nu, loc_gamma, loc_pca_d] = find_max(val_HM_seu, direct_test); fixed_bias = val_bias_seu(loc_C(1), loc_nu(1), loc_gamma(1), loc_pca_d(1)); else disp('Wrong test type!'); return; end C = opt.C(loc_C(1)); nu = opt.nu(loc_nu(1)); gamma = opt.gamma(loc_gamma(1)); pca_d = opt.pca_d(loc_pca_d(1)); disp([loc_C(1), loc_nu(1), loc_gamma(1), loc_pca_d(1), fixed_bias]); else fixed_bias = direct_test(5); C = direct_test(1); nu = direct_test(2); gamma = direct_test(3); pca_d = direct_test(4); end
这段代码首先检查变量 direct_test 是否为空或长度是否为1。如果是,则加载之前保存的结果文件,并从中获取变量 val_acc_eu、val_acc_seu、val_HM_eu、val_HM_seu、val_bias_eu 和 val_bias_seu。根据 test_type 的不同,使用函数 find_max 在相应的结果矩阵中找到最大值,并得到对应的 loc_C、loc_nu、loc_gamma 和 loc_pca_d。根据 loc_C、loc_nu、loc_gamma 和 loc_pca_d 的值,从 val_bias_eu 或 val_bias_seu 中获取 fixed_bias。
如果 direct_test 不符合上述条件,则将 fixed_bias 设置为 direct_test 的第5个元素,并将 C、nu、gamma 和 pca_d 设置为 direct_test 的前4个元素。
最后,将定位到的 loc_C(1)、loc_nu(1)、loc_gamma(1)、loc_pca_d(1) 和 fixed_bias 显示出来,并将 C、nu、gamma 和 pca_d 分别赋值给对应的变量。
阅读全文