def prepare_files(train_path, mixing, order, labels_dic, nb_groups, nb_cl, nb_val): files=os.listdir(train_path) prefix = np.array([file_i.split("_")[0] for file_i in files]) labels_old = np.array([mixing[labels_dic[i]] for i in prefix]) files_train = [] files_valid = [] for _ in range(nb_groups): files_train.append([]) files_valid.append([]) files=np.array(files) for i in range(nb_groups): for i2 in range(nb_cl): tmp_ind=np.where(labels_old == order[nb_cl*i+i2])[0] np.random.shuffle(tmp_ind) files_train[i].extend(files[tmp_ind[0:len(tmp_ind)-nb_val]]) files_valid[i].extend(files[tmp_ind[len(tmp_ind)-nb_val:]]) return files_train, files_valid
时间: 2023-06-14 07:02:30 浏览: 126
Phai_nonlinear.zip_FOUR WAVE_Four-wave mixing_fiber_four wave mi
这段代码是用来准备数据集的。函数接受训练数据的路径(train_path)、标签混淆(mixing)、标签顺序(order)、标签字典(labels_dic)、组数(nb_groups)、每组的类别数(nb_cl)和验证集大小(nb_val)等参数。
首先,获取训练数据集中所有文件的文件名(files)和文件名前缀(prefix)。然后,根据文件名前缀获取每个文件对应的标签(labels_old),并将标签进行混淆(mixing)。接着,定义训练文件列表(files_train)和验证文件列表(files_valid)。对于每一组数据(nb_groups),先为该组数据创建一个空列表,然后对于该组数据中的每个类别(nb_cl),将该类别对应的文件列表按照标签顺序(order)进行排序,并将其分为训练集和验证集。最后返回训练文件列表和验证文件列表。
阅读全文