pytorch如何加载数据集
时间: 2023-11-07 19:05:15 浏览: 88
pytorch 自定义数据集加载方法
5星 · 资源好评率100%
PyTorch提供了几种方法来加载数据集。其中一种常见的方式是使用torch.utils.data.Dataset类创建自定义数据集。你可以创建一个类,继承自torch.utils.data.Dataset,并重写__len__()和__getitem__()方法来定义你的数据集。__len__()方法应该返回数据集的大小,__getitem__()方法应该返回一个样本。例如,下面是一个自定义数据集类的示例:
```python
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data, targets):
self.data = data
self.targets = targets
def __len__(self):
return len(self.data)
def __getitem__(self, index):
x = self.data[index]
y = self.targets[index]
return x, y
```
另一种常见的方式是使用torch.utils.data.DataLoader类加载数据集。DataLoader类可以自动进行批处理、打乱和多线程加载。你可以将自定义数据集传递给DataLoader,并指定批大小、是否打乱数据集等参数。以下是一个使用DataLoader加载数据集的示例:
```python
from torch.utils.data import DataLoader
dataset = MyDataset(data, targets)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
```
此外,你还可以使用torchvision.datasets模块加载一些常见的数据集,例如MNIST、CIFAR等。这些数据集已经预处理好,并可以直接使用。你可以通过指定数据集的参数(如root、train、download等)来加载数据集。下面是一个使用torchvision.datasets加载MNIST数据集的示例:
```python
import torchvision.datasets as datasets
import torchvision.transforms as transforms
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
```
阅读全文