def data_set_split(src_data_folder, target_data_folder, train_scale=0.8, val_scale=0.1, test_scale=0.1): print("开始数据集划分") class_names = os.listdir(src_data_folder) split_names = ['train', 'val', 'test'] for split_name in split_names: split_path = os.path.join(target_data_folder, split_name) if os.path.isdir(split_path): pass else: os.mkdir(split_path) for class_name in class_names: class_split_path = os.path.join(split_path, class_name) if os.path.isdir(class_split_path): pass else: os.mkdir(class_split_path)
时间: 2023-07-14 20:12:57 浏览: 81
这段代码实现了一个数据集划分的功能,将源数据文件夹中的数据按照一定比例划分为训练集、验证集和测试集,并将划分后的数据保存在目标数据文件夹中。其中,train_scale、val_scale 和 test_scale 分别表示训练集、验证集和测试集的比例。class_names 表示源数据文件夹中的分类名称,split_names 表示要划分的数据集名称。对于每个数据集名称,先在目标数据文件夹中创建一个相应的文件夹,然后针对每个分类名称,再在相应的数据集文件夹中创建一个相应的分类文件夹。最终实现了将数据按比例划分并保存到目标文件夹的功能。
相关问题
为每句代码做注释:for class_name in class_names: current_class_data_path = os.path.join(src_data_folder, class_name) current_all_data = os.listdir(current_class_data_path) current_data_length = len(current_all_data) current_data_index_list = list(range(current_data_length)) random.shuffle(current_data_index_list) train_folder = os.path.join(os.path.join(target_data_folder, 'train'), class_name) val_folder = os.path.join(os.path.join(target_data_folder, 'val'), class_name) test_folder = os.path.join(os.path.join(target_data_folder, 'test'), class_name) train_stop_flag = current_data_length * train_scale val_stop_flag = current_data_length * (train_scale + val_scale) current_idx = 0 train_num = 0 val_num = 0 test_num = 0 for i in current_data_index_list: src_img_path = os.path.join(current_class_data_path, current_all_data[i]) if current_idx <= train_stop_flag: copy2(src_img_path, train_folder) train_num = train_num + 1 elif (current_idx > train_stop_flag) and (current_idx <= val_stop_flag): copy2(src_img_path, val_folder) val_num = val_num + 1 else: copy2(src_img_path, test_folder) # print("{}复制到了{}".format(src_img_path, test_folder)) test_num = test_num + 1 current_idx = current_idx + 1
# 循环遍历每个类别的文件夹
for class_name in class_names:
# 拼接当前类别的数据路径
current_class_data_path = os.path.join(src_data_folder, class_name)
# 获取当前类别的所有数据文件名
current_all_data = os.listdir(current_class_data_path)
# 获取当前类别的数据数量
current_data_length = len(current_all_data)
# 生成当前类别数据的索引列表
current_data_index_list = list(range(current_data_length))
# 随机打乱当前类别数据的索引列表
random.shuffle(current_data_index_list)
# 拼接训练集、验证集、测试集的路径
train_folder = os.path.join(os.path.join(target_data_folder, 'train'), class_name)
val_folder = os.path.join(os.path.join(target_data_folder, 'val'), class_name)
test_folder = os.path.join(os.path.join(target_data_folder, 'test'), class_name)
# 计算训练集、验证集、测试集在当前类别中的截止点
train_stop_flag = current_data_length * train_scale
val_stop_flag = current_data_length * (train_scale + val_scale)
# 初始化当前类别的数据索引、训练集数量、验证集数量、测试集数量
current_idx = 0
train_num = 0
val_num = 0
test_num = 0
# 循环遍历当前类别的数据索引列表,将数据复制到对应的训练集、验证集、测试集文件夹中
for i in current_data_index_list:
src_img_path = os.path.join(current_class_data_path, current_all_data[i])
# 如果当前索引在训练集截止点之前,则将数据复制到训练集
if current_idx <= train_stop_flag:
copy2(src_img_path, train_folder)
train_num = train_num + 1
# 如果当前索引在训练集截止点和验证集截止点之间,则将数据复制到验证集
elif (current_idx > train_stop_flag) and (current_idx <= val_stop_flag):
copy2(src_img_path, val_folder)
val_num = val_num + 1
# 如果当前索引在验证集截止点之后,则将数据复制到测试集
else:
copy2(src_img_path, test_folder)
# 打印当前数据的拷贝信息
# print("{}复制到了{}".format(src_img_path, test_folder))
test_num = test_num + 1
# 更新当前类别的数据索引和数量统计信息
current_idx = current_idx + 1
for class_name in class_names: current_class_data_path = os.path.join(src_data_folder, class_name) current_all_data = os.listdir(current_class_data_path) current_data_length = len(current_all_data) current_data_index_list = list(range(current_data_length)) random.shuffle(current_data_index_list) train_folder = os.path.join(os.path.join(target_data_folder, 'train'), class_name) val_folder = os.path.join(os.path.join(target_data_folder, 'val'), class_name) test_folder = os.path.join(os.path.join(target_data_folder, 'test'), class_name) train_stop_flag = current_data_length * train_scale val_stop_flag = current_data_length * (train_scale + val_scale) current_idx = 0 train_num = 0 val_num = 0 test_num = 0 解释
这段代码是一个循环,用于处理一组类别(class_names)中的数据。
在循环中,首先通过os.path.join函数将源数据文件夹(src_data_folder)路径和当前类别名称(class_name)拼接起来,得到当前类别数据的路径(current_class_data_path)。
然后,使用os.listdir函数获取当前类别数据路径下的所有文件和文件夹的名称,将结果保存在current_all_data列表中。
接下来,通过len函数计算current_all_data列表的长度,得到当前类别数据的总数量,将结果保存在current_data_length变量中。
然后,使用list函数和range函数结合,创建一个包含0到current_data_length-1的整数列表,并将结果保存在current_data_index_list变量中。这个列表用于后续对数据进行随机打乱。
随后,通过os.path.join函数将目标数据文件夹(target_data_folder)路径和训练集('train')、验证集('val')、测试集('test')子文件夹路径拼接起来,并再次使用os.path.join函数将当前类别名称拼接到各个子文件夹路径末尾,得到训练集、验证集和测试集对应的文件夹路径。
然后,根据训练集、验证集和测试集所占比例(train_scale、val_scale)和当前类别数据的总数量(current_data_length),计算出相应的停止位置(train_stop_flag、val_stop_flag)。
最后,初始化一些计数器和索引变量(current_idx、train_num、val_num、test_num),用于记录当前已处理的数据数量。
这段代码的目的是对一组类别中的数据进行处理,包括获取数据路径、计算数据数量、打乱数据索引,并确定训练集、验证集和测试集的文件夹路径以及停止位置。
阅读全文