train_data = MyDataset(txt='train_list.txt', transform=transforms.ToTensor()) AttributeError: 'list' object has no attribute 'ToTensor'
时间: 2023-06-07 20:11:40 浏览: 187
train_data.txt
这个错误提示很明显,是因为在传入数据集时将一个列表传入了 transform 参数,而列表对象没有 ToTensor() 属性。正确的做法是将数据集传入 DataReader 中,然后对数据集进行 ToTensor() 处理。代码示例:
```
import torch.utils.data as Data
import torchvision.transforms as transforms
class DataReader():
def __init__(self, txt_path):
self.txt_path = txt_path
def get_dataset(self):
with open(self.txt_path) as f:
lines = f.readlines()
dataset = []
for line in lines:
line = line.strip().split()
img_path, label = line[0], line[1]
dataset.append((img_path, label))
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
])
dataset = MyDataset(dataset, transform=transform)
return dataset
txt_path = "train_list.txt"
train_data = Data.DataLoader(
dataset=DataReader(txt_path).get_dataset(),
batch_size=32,
shuffle=True,
drop_last=True,
num_workers=2,
)
```
阅读全文