解释一下transform_list = [ transforms.ToTensor() ]
时间: 2023-06-08 10:08:16 浏览: 150
transform_list是一个包含多个图像转换操作的列表。transforms.ToTensor()是其中的一个操作,它将图像数据从PIL Image对象转换为torch张量。这个操作还会将数据缩放到0到1之间,并且将通道顺序从HWC(Hight,Width,Channel)转换为CHW(Channel,Hight,Width)。
相关问题
train_data = MyDataset(txt='train_list.txt', transform=transforms.ToTensor()) AttributeError: 'list' object has no attribute 'ToTensor'
这个错误提示很明显,是因为在传入数据集时将一个列表传入了 transform 参数,而列表对象没有 ToTensor() 属性。正确的做法是将数据集传入 DataReader 中,然后对数据集进行 ToTensor() 处理。代码示例:
```
import torch.utils.data as Data
import torchvision.transforms as transforms
class DataReader():
def __init__(self, txt_path):
self.txt_path = txt_path
def get_dataset(self):
with open(self.txt_path) as f:
lines = f.readlines()
dataset = []
for line in lines:
line = line.strip().split()
img_path, label = line[0], line[1]
dataset.append((img_path, label))
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
])
dataset = MyDataset(dataset, transform=transform)
return dataset
txt_path = "train_list.txt"
train_data = Data.DataLoader(
dataset=DataReader(txt_path).get_dataset(),
batch_size=32,
shuffle=True,
drop_last=True,
num_workers=2,
)
```
class TestDataset(Dataset): def __init__(self, folder_path, hr, lr): super(TestDataset, self).__init__() self.folder_path = folder_path self.hr = hr self.lr = lr self.hr_path = os.path.join(self.folder_path, self.hr) self.lr_path = os.path.join(self.folder_path, self.lr) self.hr_list = [entry.name for entry in os.scandir(self.hr_path) if entry.is_file()] self.lr_list = [entry.name for entry in os.scandir(self.lr_path) if entry.is_file()] self.filenames = self.hr_list self.transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) def __getitem__(self, index): self.hr_name = self.hr_list[index] self.lr_name = self.lr_list[index] self.hr_item_path = os.path.join(self.hr_path, self.hr_name) self.lr_item_path = os.path.join(self.lr_path, self.lr_name) hr_img = Image.open(self.hr_item_path).convert("RGB") lr_img = Image.open(self.lr_item_path).convert("RGB") hr_img = self.transform(hr_img) lr_img = self.transform(lr_img) return lr_img, hr_img def __len__(self): return len(self.filenames)
这是一个 Python 代码段,用于定义一个名为 TestDataset 的数据集类。该类接受三个参数:文件夹路径、高分辨率图像文件夹名称和低分辨率图像文件夹名称。它通过扫描指定文件夹中的文件来获取高分辨率和低分辨率图像的文件名列表,并将其存储在 hr_list 和 lr_list 中。在 __getitem__ 方法中,它会打开指定文件夹中的图像文件,并将其转换为张量,并返回低分辨率和高分辨率图像的张量。在 __len__ 方法中,它返回文件名列表的长度。
阅读全文