from torch.utils.data.sampler import SubsetRandomSampler
时间: 2023-07-23 11:57:38 浏览: 217
python torch.utils.data.DataLoader使用方法
5星 · 资源好评率100%
SubsetRandomSampler是PyTorch中的一个采样器类,用于在给定数据集的子集上进行随机采样。它可以用于创建mini-batch训练数据集和验证数据集。
在使用SubsetRandomSampler时,你需要提供一个索引列表,该列表表示数据集中的样本索引。然后SubsetRandomSampler将根据这些索引随机选择样本。
下面是一个使用SubsetRandomSampler的示例:
```python
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SubsetRandomSampler
# 创建自定义数据集
class CustomDataset(Dataset):
def __init__(self):
self.data = [1, 2, 3, 4, 5]
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return len(self.data)
# 创建数据集实例
dataset = CustomDataset()
# 设定训练集索引
train_indices = [0, 1, 2]
# 设定验证集索引
val_indices = [3, 4]
# 创建采样器实例
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)
# 创建数据加载器
train_loader = DataLoader(dataset, batch_size=2, sampler=train_sampler)
val_loader = DataLoader(dataset, batch_size=2, sampler=val_sampler)
# 在训练集上进行迭代
for batch in train_loader:
print(batch)
# 在验证集上进行迭代
for batch in val_loader:
print(batch)
```
在上面的示例中,我们创建了一个自定义的数据集CustomDataset,并使用SubsetRandomSampler将数据集划分为训练集和验证集。然后,我们可以使用DataLoader加载数据集,并通过迭代器访问数据集的mini-batch。
希望这个例子能帮助你理解SubsetRandomSampler的用法。如果还有其他问题,请继续提问。
阅读全文