请列举下pytorch的dataset和dataloader的使用程序
时间: 2023-09-18 19:15:55 浏览: 95
好的,下面是使用PyTorch中的Dataset和Dataloader的示例程序:
首先,我们需要导入相关的库:
```python
import torch
from torch.utils.data import Dataset, DataLoader
```
接下来我们定义一个数据集类,这个类需要继承自`torch.utils.data.Dataset`,并实现`__len__`和`__getitem__`方法。这里我们假设我们有一个包含10个样本的数据集,每个样本由一个图片和一个标签组成:
```python
class MyDataset(Dataset):
def __init__(self):
self.data = [(torch.randn(3, 224, 224), i) for i in range(10)]
def __len__(self):
return len(self.data)
def __getitem__(self, index):
x, y = self.data[index]
return x, y
```
接下来,我们可以使用`MyDataset`类创建一个数据集对象,并使用`DataLoader`类创建一个数据加载器对象。在创建`DataLoader`对象时,我们需要指定`batch_size`和`shuffle`参数:
```python
my_dataset = MyDataset()
my_dataloader = DataLoader(my_dataset, batch_size=2, shuffle=True)
```
现在我们可以使用`my_dataloader`迭代数据集中的样本了,每个迭代器返回一个包含`batch_size`个样本的元组,其中第一个元素是一个大小为`(batch_size, 3, 224, 224)`的张量,代表`batch_size`个图片,第二个元素是一个大小为`(batch_size,)`的张量,代表`batch_size`个标签。我们可以使用下面的代码来迭代数据集:
```python
for x, y in my_dataloader:
print(x.shape, y.shape)
```
输出结果如下:
```
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
torch.Size([2, 3, 224, 224]) torch.Size([2])
```
这个程序演示了如何使用PyTorch中的Dataset和Dataloader来加载数据集,并迭代数据集中的样本。
阅读全文