pytorch加载数据集
时间: 2023-10-18 18:25:46 浏览: 122
PyTorch加载自己的数据集实例详解
在PyTorch中加载数据集通常有两种常见的方法:使用自定义数据集和使用预定义数据集。
1. 使用自定义数据集:
- 创建一个新的Python类,继承`torch.utils.data.Dataset`,并实现`__len__`和`__getitem__`方法。`__len__`返回数据集的大小,`__getitem__` 根据给定索引返回样本。
- 在`__init__`方法中,根据需求加载数据集并对其进行预处理。
- 可以使用PyTorch提供的各种数据转换方法(例如`torchvision.transforms`)来对数据进行预处理。
- 在训练代码中,实例化自定义数据集类,并使用`torch.utils.data.DataLoader`将数据加载到训练循环中。
下面是一个简单的自定义数据集加载示例:
```python
import torch
from torch.utils.data import Dataset, DataLoader
class CustomDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = self.data[idx]
# 在这里进行数据预处理
return torch.Tensor(sample)
# 假设有一个包含样本的列表
data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
dataset = CustomDataset(data)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
for batch in dataloader:
# 在这里执行训练循环
print(batch)
```
2. 使用预定义数据集:
- PyTorch提供了一些预定义的数据集,如`torchvision.datasets`模块中的MNIST、CIFAR10等。
- 可以使用预定义数据集的构造函数来加载数据集,并根据需要进行转换和预处理。
- 同样,可以使用`torch.utils.data.DataLoader`将数据加载到训练循环中。
下面是一个预定义数据集加载示例:
```python
import torch
import torchvision
from torchvision import transforms
# 定义数据转换和预处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 加载MNIST数据集
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform)
# 使用DataLoader加载数据集
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
for batch in train_dataloader:
# 在这里执行训练循环
images, labels = batch
print(images.shape, labels.shape)
```
这些是基本的加载数据集的方法,你可以根据自己的需求进行修改和扩展。
阅读全文