写一个载入文字的pytorch dataloader数据载入代码
时间: 2024-03-26 13:34:15 浏览: 91
好的,以下是一个基本的PyTorch DataLoader数据载入代码,用于载入文本数据:
```python
import torch
from torch.utils.data import Dataset, DataLoader
class TextDataset(Dataset):
def __init__(self, text_path):
# 读取文本文件
with open(text_path, 'r', encoding='utf-8') as f:
self.text = f.read()
# 将文本分割成句子
self.sentences = self.text.split('\n')
def __len__(self):
return len(self.sentences)
def __getitem__(self, idx):
# 对每个句子进行处理
sentence = self.sentences[idx]
# 将句子转化为数字表示,这里可以使用自己的分词方法
tokens = [int(token) for token in sentence.split()]
# 返回数字表示的句子
return torch.tensor(tokens)
# 定义数据载入器
def get_dataloader(text_path, batch_size):
dataset = TextDataset(text_path)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
return dataloader
```
这个代码中,我们定义了一个TextDataset类,用于读取文本文件并将文本分割成句子。在getitem方法中,我们将每个句子转化为数字表示,并返回数字表示的句子。
然后,我们定义了一个get_dataloader函数,用于创建数据载入器。在这个函数中,我们创建了一个TextDataset实例,并将其传入DataLoader中,设置了batch_size和shuffle等参数。最后返回一个dataloader对象。
你可以根据自己的需求对这个代码进行修改,例如添加Padding等步骤。
阅读全文