torch.utils.data详解,附代码
时间: 2023-08-25 18:26:57 浏览: 152
torch.utils.data 是 PyTorch 中用于数据处理和加载的模块。它提供了一系列工具来处理各种数据集,并且能够将数据转换为 PyTorch 中的 Tensor 对象。下面我们来详解一下其中的几个类和方法,并给出示例代码。
## Dataset
Dataset 是一个抽象类,用于表示数据集。我们需要继承这个类并实现 __len__ 和 __getitem__ 两个方法。其中,__len__ 方法应返回数据集的大小,__getitem__ 方法应根据给定的索引返回对应的数据样本。
示例代码:
```python
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data):
# data 是一个数据列表,每个元素是一个二元组 (input, label)
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
input, label = self.data[idx]
return input, label
```
## DataLoader
DataLoader 是用于将 Dataset 中的数据批量加载到内存中的类。我们可以设置 batch_size、shuffle、num_workers 等参数。
示例代码:
```python
from torch.utils.data import DataLoader
dataset = MyDataset(data)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
for batch in dataloader:
inputs, labels = batch
# 在这里对输入和标签进行处理
```
## Sampler
Sampler 是用于决定 DataLoader 如何从 Dataset 中采样数据的类。PyTorch 中提供了多种 Sampler,包括 RandomSampler、SequentialSampler、SubsetRandomSampler、WeightedRandomSampler 等。
示例代码:
```python
from torch.utils.data import RandomSampler
sampler = RandomSampler(dataset)
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
for batch in dataloader:
inputs, labels = batch
# 在这里对输入和标签进行处理
```
## Subset
Subset 类表示 Dataset 的一个子集,可以通过 start 和 end 索引来指定子集的范围。
示例代码:
```python
from torch.utils.data import Subset
subset = Subset(dataset, start=0, end=100)
dataloader = DataLoader(subset, batch_size=32)
for batch in dataloader:
inputs, labels = batch
# 在这里对输入和标签进行处理
```
## BatchSampler
BatchSampler 类将数据划分为 batch,并采用指定的 Sampler 对每个 batch 进行采样。
示例代码:
```python
from torch.utils.data import BatchSampler
batch_sampler = BatchSampler(RandomSampler(dataset), batch_size=32, drop_last=False)
dataloader = DataLoader(dataset, batch_sampler=batch_sampler)
for batch in dataloader:
inputs, labels = batch
# 在这里对输入和标签进行处理
```
这些类和方法提供了非常方便的数据处理和加载功能,可以帮助我们快速构建深度学习模型。
阅读全文