def load_data( *, data_dir, batch_size, image_size, class_cond=False, deterministic=False, random_crop=False, random_flip=False, ): mage.save_img(image_path, images) if not data_dir: raise ValueError("unspecified data directory") all_files = list_image_files_recursively(data_dir) # 递归的去找到所有图片文件 classes = None if class_cond: # Assume classes are the first part of the filename, # before an underscore. class_names = [bf.basename(path).split("")[0] for path in all_files] sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))} # class id 将每一类class做成一个整型变量 classes = [sorted_classes[x] for x in class_names] dataset = ImageDataset( image_size, all_files, classes=classes, shard=MPI.COMM_WORLD.Get_rank(), num_shards=MPI.COMM_WORLD.Get_size(), random_crop=random_crop, random_flip=random_flip, ) if deterministic: loader = DataLoader( dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True ) else: loader = DataLoader( dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True ) while True: yield from loader,data = load_data( data_dir="F://diffuse_model_ptcg//picture//test", batch_size=4, image_size=128, class_cond=False, )for batch in data:load_data当把数据集里所有的数据取完后,下一个batch将生成什么
时间: 2023-06-27 17:01:28 浏览: 86
当所有的数据集被取完后,下一个batch将会重新从头开始循环取数据,即重新开始一个新的epoch。这是因为在代码中设置了`drop_last=True`,即如果最后一个batch的样本数量不足batch_size,那么这个batch将被丢弃。因此,每个epoch的最后一个batch可能不满足batch_size,但每个epoch的第一个batch将始终是第一批数据。
相关问题
data = load_data( data_dir=r"F:\diffuse_model_ptcg\img_ed_ad_pc", batch_size=1, image_size=128, class_cond=False, deterministic=False, random_crop=False, random_flip=False, ) batch, cond = next(data) print(batch.tensor)
根据你提供的信息,我猜测你使用的是 PyTorch 框架中的 Torchvision 库,其中 `load_data` 函数是来自第三方库 `image_datasets`,并且 `load_data` 函数返回的是一个迭代器对象,可以通过 `next()` 函数来获取迭代器的下一个元素。
在调用 `next(data)` 之前,你需要先确保 `data` 是一个有效的迭代器对象。如果 `data` 不是一个有效的迭代器对象,或者迭代器中已经没有下一个元素,那么调用 `next(data)` 就会抛出一个异常,导致程序崩溃。
你需要检查一下 `data` 对象是否有效,以及迭代器中是否还有下一个元素。如果 `data` 对象无效,你需要检查一下 `load_data` 函数的实现,看看是否有问题。如果迭代器中已经没有下一个元素,你需要重新获取一个新的迭代器对象,然后再次调用 `next()` 函数来获取下一个元素。
def load_data( *, data_dir, batch_size, image_size, class_cond=False, deterministic=False, random_crop=False, random_flip=False, ): mage.save_img(image_path, images) if not data_dir: raise ValueError("unspecified data directory") all_files = _list_image_files_recursively(data_dir) # 递归的去找到所有图片文件 classes = None if class_cond: # Assume classes are the first part of the filename, # before an underscore. class_names = [bf.basename(path).split("_")[0] for path in all_files] sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))} # class id 将每一类class做成一个整型变量 classes = [sorted_classes[x] for x in class_names] dataset = ImageDataset( image_size, all_files, classes=classes, shard=MPI.COMM_WORLD.Get_rank(), num_shards=MPI.COMM_WORLD.Get_size(), random_crop=random_crop, random_flip=random_flip, ) if deterministic: loader = DataLoader( dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True ) else: loader = DataLoader( dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True ) while True: yield from loader,data = load_data( data_dir="F://diffuse_model_ptcg//picture//test", batch_size=4, image_size=128, class_cond=False, )for batch in data:load_data中batch是无尽的吗
是的,load_data函数返回的是一个Python生成器对象,使用yield from语句使其不断地产生新的数据批次。因此,使用for循环遍历batch时可以得到无尽的数据批次,直到程序被手动终止或者发生错误。需要注意的是,由于数据集的大小通常是有限的,因此在使用生成器对象时应该谨慎,避免产生过多的数据批次导致内存溢出等问题。
阅读全文