DataLoader对数据集进行shuffle=True的抽取时,如何能使得一个batch中不出现相同label的数据
时间: 2023-12-06 12:44:11 浏览: 138
pytorch中如何使用DataLoader对数据集进行批处理的方法
在使用DataLoader对数据集进行shuffle=True的抽取时,无法保证一个batch中不出现相同label的数据。不过,可以通过以下方式来实现这个需求:
1. 在数据集加载时,将数据按照label进行分类,得到一个字典,其中每个键对应一个标签类别,值为该类别下的所有数据。
2. 在每个epoch开始前,对字典中每个键对应的数据进行shuffle操作,以保证每个类别下的数据顺序不同。
3. 在DataLoader创建时,使用sampler参数指定一个自定义的Sampler,该Sampler会根据字典中每个键对应的数据数量,生成一个随机的采样顺序。
4. 在每个batch被抽取前,对于每个类别,先从该类别下的所有数据中选择一个样本,将其添加到batch中。然后,从剩余的数据中随机选择(batch_size-1)个样本,也添加到batch中。
通过以上方式,可以保证每个batch中都至少包含一个类别的数据,并且不会出现相同label的数据。
阅读全文