用pytorch写一个代码把数据集的图片都加载出来
时间: 2024-02-28 12:54:33 浏览: 33
好的,下面是一个用 PyTorch 加载数据集的代码示例:
```python
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
class MyDataset(Dataset):
def __init__(self, img_list, transform=None):
self.img_list = img_list
self.transform = transform
def __len__(self):
return len(self.img_list)
def __getitem__(self, idx):
img_path = self.img_list[idx]
img = Image.open(img_path).convert('RGB')
if self.transform:
img = self.transform(img)
return img
# 定义数据集路径
data_path = '/path/to/your/dataset'
# 定义数据预处理方式
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
])
# 加载数据集
dataset = MyDataset(img_list=sorted(glob.glob(os.path.join(data_path, '*.jpg'))), transform=transform)
# 定义数据加载器
batch_size = 32
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)
# 遍历数据集
for imgs in data_loader:
# imgs 的 shape 为 [batch_size, 3, 224, 224]
# 这里可以对 imgs 进行操作,比如输入到模型中进行推理
pass
```
该代码定义了一个 `MyDataset` 类,用于加载数据集中的图片。在 `MyDataset` 类中,`__init__` 方法用于初始化数据集,`__len__` 方法返回数据集大小,`__getitem__` 方法用于获取数据集中的一张图片。
我们使用 PyTorch 内置的 `DataLoader` 类来实现数据的批量加载和多线程读取,可以设置 `batch_size` 参数控制每次读取的图片数量,设置 `num_workers` 参数控制多线程读取的线程数。