torch.utils.data.DataLoader遍历代码
时间: 2023-10-16 09:24:40 浏览: 41
以下是使用torch.utils.data.DataLoader遍历数据集的示例代码:
```python
import torch
from torch.utils.data import DataLoader, Dataset
# 创建自定义数据集
class CustomDataset(Dataset):
def __init__(self, data):
self.data = data
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return len(self.data)
# 创建数据集
data = list(range(100))
# 创建数据集实例
dataset = CustomDataset(data)
# 创建数据加载器实例
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
# 遍历数据集
for batch in dataloader:
print(batch)
```
在上述代码中,首先创建了一个自定义数据集`CustomDataset`,并实现了`__getitem__`和`__len__`方法。然后,将数据集实例化,并传递给`DataLoader`,设置了批量大小为10,且打乱了数据集顺序。最后,使用一个简单的for循环遍历数据集,每次获取一个批次的数据,并将其打印出来。
相关问题
torch.utils.data.dataloader样例代码
下面是一个使用`torch.utils.data.dataloader`的样例代码:
```python
import torch
from torch.utils.data import Dataset, DataLoader
# 自定义数据集类
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return len(self.data)
# 实例化数据集
data = list(range(100))
dataset = MyDataset(data)
# 实例化数据加载器
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
# 遍历数据加载器
for batch in dataloader:
print(batch)
```
该代码中,我们首先定义了一个自定义数据集类`MyDataset`,并实现了`__getitem__`和`__len__`方法。然后我们将数据集实例化,并传入`DataLoader`中,设置了`batch_size`为10,`shuffle`为True。最后我们遍历数据加载器并打印每个batch的数据。
torch.utils.data。Data.DataLoader
`torch.utils.data`是PyTorch中用于处理数据的一个模块,提供了一些常用的数据集类和数据处理工具函数。其中,`DataLoader`是一个数据加载器,可以将一个数据集封装为一个可迭代的数据加载器,方便地进行批量数据读取。
`DataLoader`的常见用法是将数据集传入,然后使用`batch_size`参数指定每个批次的数据量,使用`shuffle`参数指定是否随机打乱数据集。
下面是一个示例代码:
```python
import torch
from torch.utils.data import DataLoader, Dataset
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return len(self.data)
data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
dataset = MyDataset(data)
dataloader = DataLoader(
dataset,
batch_size=3,
shuffle=True
)
for batch in dataloader:
print(batch)
```
这段代码中,我们首先定义了一个自定义的数据集类`MyDataset`,并将一个列表作为数据集传入。然后我们使用`DataLoader`将数据集封装为一个可迭代的数据加载器,并指定每个批次的数据量为3,设置`shuffle=True`表示每次迭代时都会随机打乱数据集。最后,我们使用`for`循环遍历数据加载器,每次迭代都会返回一个大小为3的批次数据。
输出结果为:
```
tensor([10, 9, 2])
tensor([4, 7, 3])
tensor([8, 1, 6])
tensor([5])
```
可以看到,输出结果是一个个大小为3的批次数据,最后一个批次只有一个数据。