使用Dataset与Dataloader加载数据 - 首次加载数据使用直接下载的方式,数据存放至./data文件夹
时间: 2024-02-09 08:11:03 浏览: 75
利用torch.utils.data.Dataset自定义数据加载类
3星 · 编辑精心推荐
好的,以下是代码实现:
```python
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import os
# 定义自定义的MINST数据集
class CustomMNISTDataset(Dataset):
def __init__(self, train=True, transform=None):
self.train = train
self.transform = transform
if self.train:
self.data, self.labels = torch.utils.data.Subset(torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor()), list(range(0, 50000)))
else:
self.data, self.labels = torch.utils.data.Subset(torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor()), list(range(50000, 60000)))
def __getitem__(self, index):
img, target = self.data[index], self.labels[index]
if self.transform is not None:
img = self.transform(img)
return img, target
def __len__(self):
return len(self.data)
# 定义数据预处理方式
transform = transforms.Compose([
transforms.Normalize((0.1307,), (0.3081,))
])
# 加载训练集和测试集
train_dataset = CustomMNISTDataset(train=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_dataset = CustomMNISTDataset(train=False, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)
```
注意:这段代码在第一次运行时会直接下载MINST数据集,存放至./data文件夹下。如果之后再次运行需要手动删除./data文件夹下的所有文件。同时,为了减小训练集大小,代码中只使用前50000个样本作为训练集,后10000个样本作为测试集。如果需要使用完整的训练集,可以将`self.data, self.labels`的赋值语句改为`torchvision.datasets.MNIST(root='./data', train=self.train, download=True, transform=transforms.ToTensor())`。
阅读全文