帮我用pytorch写一个读取NUE-WIDE数据集的图片和标签的代码,分测试集和数据集
时间: 2024-02-13 11:00:55 浏览: 112
好的,以下是用PyTorch读取NUE-WIDE数据集的图片和标签的代码,同时将数据集分为训练集和测试集:
```python
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import os
class NUEWIDEDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.images = []
self.labels = []
for label in os.listdir(root_dir):
for image_file in os.listdir(os.path.join(root_dir, label)):
self.images.append(os.path.join(root_dir, label, image_file))
self.labels.append(int(label))
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
image = Image.open(self.images[idx]).convert('RGB')
label = self.labels[idx]
if self.transform:
image = self.transform(image)
return image, label
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
dataset = NUEWIDEDataset(root_dir='path/to/dataset', transform=transform)
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [int(len(dataset)*0.8), len(dataset)-int(len(dataset)*0.8)])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)
```
在这个示例中,我们首先定义了一个`NUEWIDEDataset`类,它继承了`Dataset`类,并实现了`__init__`、`__len__`和`__getitem__`方法。在`__init__`方法中,我们遍历数据集中的每个标签和图像文件,并将它们分别存储在`self.labels`和`self.images`中。在`__getitem__`方法中,我们打开图像文件并将其转换为`PIL.Image`对象,并返回图像和标签。
接下来,我们使用`transforms`模块定义了一个转换序列,该序列将图像大小调整为`(224, 224)`,将其转换为`Tensor`对象,并将其归一化为范围`[-1,1]`。然后,我们使用`NUEWIDEDataset`类初始化了一个数据集对象,并使用`torch.utils.data.random_split`函数将数据集分成了训练集和测试集。最后,我们使用`DataLoader`类将数据集转换为可迭代的数据加载器,并设置了每个批次的大小和是否随机打乱数据集。
阅读全文