pytorch中的__getitem__
时间: 2023-12-17 12:05:11 浏览: 87
Python类中方法getitem和getattr详解
`__getitem__`是Python中的一个特殊方法,用于获取对象中指定索引或键对应的值。在PyTorch中,`__getitem__`方法被用于数据集类中,用于获取指定索引的数据样本。具体来说,当我们使用PyTorch中的`DataLoader`加载数据时,`DataLoader`会调用数据集类中的`__getitem__`方法来获取指定索引的数据样本。下面是一个简单的例子:
```python
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return len(self.data)
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
for batch in dataloader:
print(batch)
```
在上面的例子中,我们定义了一个`MyDataset`类,其中`__getitem__`方法返回指定索引的数据样本。然后我们创建了一个数据集对象`dataset`,并将其传递给`DataLoader`中。`DataLoader`会自动调用`dataset`中的`__getitem__`方法来获取数据样本。最后我们遍历`DataLoader`,每次输出两个数据样本。
阅读全文