构建自定义数据集,代码实现
时间: 2023-10-05 22:09:06 浏览: 34
下面是一个简单的代码示例,用于构建自定义数据集:
```python
import os
import random
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data_dir, transform=None):
self.data_dir = data_dir
self.transform = transform
self.img_list = os.listdir(data_dir)
self.label_list = [self.get_label(img_name) for img_name in self.img_list]
def __len__(self):
return len(self.img_list)
def __getitem__(self, idx):
img_path = os.path.join(self.data_dir, self.img_list[idx])
img = Image.open(img_path).convert('RGB')
label = self.label_list[idx]
if self.transform is not None:
img = self.transform(img)
return img, label
def get_label(self, img_name):
# 根据图片名字获取标签
return img_name.split('_')[0]
```
以上代码中,`CustomDataset` 类继承自 `torch.utils.data.Dataset` 类,通过实现 `__len__` 和 `__getitem__` 方法来定义自己的数据集。其中,`__len__` 方法返回数据集的长度,`__getitem__` 方法根据索引返回对应的数据和标签。
构建数据集的过程中,我们可以先将数据集中的图片文件名和标签读取到内存中,然后在 `__getitem__` 方法中根据索引读取对应的图片和标签。在读取图片时,我们可以使用 `PIL` 库来进行图片的格式转换和预处理。对于标签的处理,可以根据文件名或其他信息来进行解析。
最后,我们可以通过实现 `get_label` 方法来获取图片的标签,该方法根据图片名字来进行解析。需要根据实际情况对其进行修改和扩展。