pytorch dataloader 获取一个batch
时间: 2023-05-30 21:06:06 浏览: 572
可以使用`torch.utils.data.DataLoader`来获取一个batch。以下是一个示例:
```python
import torch
from torch.utils.data import DataLoader, Dataset
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return torch.Tensor(self.data[index])
# 创建数据集
data = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
dataset = MyDataset(data)
# 创建数据加载器
batch_size = 2
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# 获取一个batch
for batch in dataloader:
print(batch)
```
输出:
```
tensor([[ 1., 2., 3.],
[10., 11., 12.]])
tensor([[4., 5., 6.],
[7., 8., 9.]])
```
可以看到,我们成功地获取了一个batch。在这个示例中,我们创建了一个包含4个样本的数据集,然后使用`DataLoader`来创建数据加载器,其中`batch_size`设置为2。在`for`循环中,我们使用`dataloader`来获取一个batch,每个batch包含2个样本。由于我们设置了`shuffle=True`,因此每个batch中的样本顺序是随机的。
阅读全文