写一个用dataset导入图片集并遍历图片的代码
时间: 2023-05-24 08:03:57 浏览: 112
关于图遍历的一个算法代码
A:
这是一个基于Python和Pytorch的代码,使用Dataset类和DataLoader类导入图片集并遍历图片。
```python
import os
import numpy as np
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
class ImageDataset(Dataset):
def __init__(self, data_dir, transform=None):
self.data_dir = data_dir
self.image_paths = os.listdir(data_dir)
self.transform = transform
def __getitem__(self, index):
image_path = self.image_paths[index]
image = Image.open(os.path.join(self.data_dir, image_path))
if self.transform:
image = self.transform(image)
return image
def __len__(self):
return len(self.image_paths)
if __name__ == '__main__':
data_dir = './image_folder'
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
dataset = ImageDataset(data_dir=data_dir, transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)
for i, batch in enumerate(dataloader):
print(f"Batch {i}, shape = {batch.shape}")
```
解释:
首先定义了一个继承Pytorch中的Dataset类的ImageDataset类,用于加载图片数据。在初始化函数中,我们将指定的数据目录中的所有图片路径保存到image_paths实例变量中,并通过transform参数将传入的图像转换器保存到transform实例变量中。
在__getitem__方法中,我们获取指定索引的图像路径并打开图像。如果transform有值,则将图像转换为传入transform的函数链。最后,我们返回转换后的图像。
在__len__方法中,我们返回图像集中的图像数量。
在if __name__ == '__main__':语句块中,我们构建了一个ImageDataset实例,并传递了包含图像变换的transform参数。然后,我们创建一个DataLoader实例,用于生成批次数据。我们循环遍历DataLoader实例,并在每个批次中输出当前批次的形状。
这个简单的代码片段演示了如何使用Dataset和DataLoader类轻松地加载和遍历图像集数据。
阅读全文