如何使用collate_fn
时间: 2023-08-29 08:08:54 浏览: 101
使用`collate_fn`是在训练神经网络时,对输入的批次数据进行预处理和整理的一个函数。它将原始的样本组成的列表转化为可以输入到神经网络的张量。
下面是一个示例代码,展示了如何使用`collate_fn`:
```python
import torch
from torch.utils.data import DataLoader
# 假设你有一个自定义的数据集类 MyDataset
class MyDataset(torch.utils.data.Dataset):
def __init__(self, data):
self.data = data
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return len(self.data)
# 假设你的训练数据是一个列表,每个元素是一个样本
train_data = [torch.tensor([1, 2]), torch.tensor([3, 4]), torch.tensor([5, 6])]
# 创建自定义数据集对象
dataset = MyDataset(train_data)
# 创建一个数据加载器,并指定 collate_fn 参数
dataloader = DataLoader(dataset, batch_size=2, collate_fn=torch.stack)
# 遍历数据加载器
for batch in dataloader:
print(batch)
```
在上述代码中,我们首先定义了一个自定义的数据集类`MyDataset`,然后创建了一个训练数据列表`train_data`。接着,我们使用`MyDataset`类实例化了一个数据集对象`dataset`。
然后,我们创建了一个数据加载器`dataloader`,其中指定了`batch_size`参数为2,并且将`collate_fn`参数设置为`torch.stack`函数。`torch.stack`函数用于将列表中的张量堆叠成一个张量。
最后,我们遍历数据加载器,每次迭代得到一个批次的数据。在这个例子中,输出将会是两个样本的张量组成的批次。你可以根据自己的需求定义并使用不同的`collate_fn`函数来处理不同类型的数据。
阅读全文