我的每个文件夹是不同标签
时间: 2024-12-18 11:30:58 浏览: 7
明白了,如果你的每个子文件夹代表不同的标签,那么`CustomImageDataset`可以稍作调整以反映这种分类结构。在这种情况下,子文件夹名可以视为标签名称,而文件夹内图片对应相应的标签。下面是更新后的类定义:
```python
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from torchvision import transforms
import os
class CustomTaggedImageDataset(Dataset):
def __init__(self, root_dir, tag_list, transform=None):
self.root_dir = root_dir
self.tag_list = tag_list
self.transform = transform
self.filepaths = {tag: [] for tag in tag_list}
for tag in tag_list:
folder_path = os.path.join(root_dir, tag)
for filename in os.listdir(folder_path):
img_path = os.path.join(folder_path, filename)
if os.path.isfile(img_path) and filename.endswith(('.jpg', '.jpeg', '.png')):
self.filepaths[tag].append(img_path)
def __len__(self):
return sum(len(v) for v in self.filepaths.values())
def __getitem__(self, idx):
label = self.find_label_for_idx(idx)
path = self.filepaths[label][idx]
image = Image.open(path).convert('RGB') # 假设图片是RGB格式
if self.transform:
image = self.transform(image)
return image, label
def find_label_for_idx(self, idx):
current_idx = 0
for label, files in self.filepaths.items():
if current_idx + len(files) > idx:
return label
current_idx += len(files)
raise IndexError(f"Index {idx} out of range for dataset with {self.__len__()}")
# 使用自定义数据集
tag_list = ['tag1', 'tag2', 'tag3', 'tag4', 'tag5'] # 替换为你的实际标签
dataset = CustomTaggedImageDataset(root_dir='your_directory_path', tag_list=tag_list, transform=transforms.ToTensor())
```
在这个版本中,我们使用`os.listdir`遍历每个标签文件夹,收集图片文件路径。`__getitem__`方法会根据索引找到相应的标签和图像路径。`find_label_for_idx`函数确定索引所对应的标签。
阅读全文