def data_split(data, rate): train_l = int(len(data) * rate) test_l = len(data) - train_l """打乱数据集并且划分""" train_set, test_set = torch.utils.data.random_split(data, [train_l, test_l]) return train_set, test_set
时间: 2024-04-27 20:24:12 浏览: 67
数据集分割train和test程序
这是一个数据集划分函数,用于将数据集按照给定比例划分为训练集和测试集。
输入参数 data 是一个 PyTorch 数据集对象,rate 是训练集所占比例,取值范围为 (0, 1)。
在函数内部,首先计算出训练集和测试集的样本数量,然后调用 PyTorch 中的 torch.utils.data.random_split() 函数将数据集随机划分为训练集和测试集,这个函数返回的是两个新的数据集对象。最后将训练集和测试集分别返回。
阅读全文