class DiabetesDataset(Dataset): def __init__(self, filepath): xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32) self.len = xy.shape[0] # shape(多少行,多少列) self.x_data = torch.from_numpy(xy[:, :-1]) self.y_data = torch.from_numpy(xy[:, [-1]]) def __getitem__(self, index): return self.x_data[index], self.y_data[index] def __len__(self): return self.len
时间: 2023-04-10 14:02:45 浏览: 150
糖尿病分类的数据集中Diabetes Dataset
这是一个名为DiabetesDataset的类,继承自Dataset类。它的构造函数__init__接受一个文件路径作为参数。在构造函数中,使用numpy库的loadtxt函数从指定路径的文件中读取数据,数据以逗号分隔,数据类型为float32。读取的数据被存储在变量xy中。self.len被赋值为xy的行数,即数据集的大小。
阅读全文