我想做一个知识问答的模型,请为我写一段数据加载的代码,数据分为<question,answer,lable>三个元组,如果answer是question的正确答案,则lable为1
时间: 2024-02-03 09:14:40 浏览: 156
好的,这里给您提供一个简单的数据加载代码示例,假设您的数据存储在一个CSV文件中,每一行为一个<question, answer, label>的元组:
```python
import csv
from torch.utils.data import Dataset
class QADataset(Dataset):
def __init__(self, file_path):
self.data = []
with open(file_path, 'r', encoding='utf-8') as f:
reader = csv.reader(f, delimiter=',')
for row in reader:
question, answer, label = row[0], row[1], int(row[2])
self.data.append((question, answer, label))
def __len__(self):
return len(self.data)
def __getitem__(self, index):
question, answer, label = self.data[index]
return question, answer, label
```
其中,我们首先定义了一个`QADataset`类,继承自PyTorch中的`Dataset`类。在构造函数中,我们使用Python内置的`csv`模块读取CSV文件,并将每一行转换成一个元组`(question, answer, label)`,并保存在`self.data`中。在`__len__`方法中,我们实现了获取数据集大小的功能,返回`self.data`的长度即可。在`__getitem__`方法中,我们实现了获取指定索引的数据样本的功能,返回该索引对应的元组即可。这样,我们就可以使用PyTorch中的`DataLoader`类来对数据进行批量处理了。
阅读全文