class Dataset(torch.utils.data.Dataset): def __init__(self, root): self.root = root if not os.path.exists(self.root): raise Exception("[!] {} not exists.".format(root)) # 这个类的目的是为了读取数据集,如果数据集不存在,就无法读取,因此在构造函数中进行判断,可以保证后续的代码能够正常运行 # sort file names 文件名排序 self.input_paths = sorted( glob(os.path.join(self.root, '{}/*_train.mat'.format("GB_data/Real/noise_data/" + Noise + "/train_data")))) self.label_paths = sorted( glob(os.path.join(self.root, '{}/*_lab.mat'.format("GB_data/Real/noise_data/" + Noise + "/train_lab")))) self.name = os.path.basename(root) # print(self.input_paths) # print(self.label_paths) if len(self.input_paths) == 0 or len(self.label_paths) == 0: raise Exception("No signal/labels are found in {}".format(self.root))
时间: 2024-04-03 20:30:33 浏览: 79
python torch.utils.data.DataLoader使用方法
5星 · 资源好评率100%
这段代码定义了一个继承自PyTorch中Dataset类的自定义数据集类。在构造函数__init__中,接收一个参数root,表示数据集的根目录。如果数据集不存在,就会抛出异常。然后使用glob函数获取数据集中所有训练数据文件的路径,并对文件名进行排序。接着,获取与训练数据对应的标签文件的路径,并同样进行排序。最后,判断训练数据和标签文件的数量是否一致,如果不一致,也会抛出异常。这个类的目的是为了方便后面的代码能够读取和处理数据集,从而进行深度学习模型的训练。
阅读全文