transform = transforms.Compose([ transforms.Grayscale(), transforms.Resize((28, 28)), transforms.ToTensor() ])如何加载自己的数据集
时间: 2024-05-06 18:19:03 浏览: 141
假设你的自己的数据集的文件夹路径为 `./my_dataset`,其中包含了训练集和测试集两个文件夹,每个文件夹下又有多张图片。
你可以使用 PyTorch 的 `ImageFolder` 类来加载自己的数据集,代码如下:
```python
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
# 定义数据增强和转换
transform = transforms.Compose([
transforms.Grayscale(),
transforms.Resize((28, 28)),
transforms.ToTensor()
])
# 加载训练集数据
train_dataset = ImageFolder('./my_dataset/train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# 加载测试集数据
test_dataset = ImageFolder('./my_dataset/test', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
```
其中 `ImageFolder` 类会自动扫描指定路径下的文件夹和图片,并将它们转换成数据集。你只需要指定数据增强和转换的方式,并传入文件夹路径即可。
最后,你可以通过 `DataLoader` 类来批量读取数据。在这个例子中,我们定义了一个批次大小为 32 的数据加载器,用于训练集和测试集的加载。
阅读全文