请帮助我用 PyTorch 写一个批量读取文本的类，该类继承 torch.nn.Module。
时间: 2023-02-18 19:26:35 浏览: 115
可以参考下面的代码:
class BatchReader(torch.nn.Module):
    """Read a text file and yield its lines in batches.

    Each batch is a list of raw lines exactly as ``readlines()`` returns
    them (trailing newlines are kept). The file is reopened with a context
    manager on every pass, so no file handle stays open and the reader can
    be iterated more than once.

    NOTE(review): subclassing ``torch.nn.Module`` is kept only for interface
    compatibility; ``torch.utils.data.Dataset`` + ``DataLoader`` is the
    usual PyTorch tool for batched data loading.

    Args:
        text_path: Path of the text file to read.
        batch_size: Number of lines per batch. ``None`` (the default) yields
            the whole file as a single batch, which is what the previous
            implementation effectively did.
        encoding: Text encoding passed to ``open()``. ``None`` (the default)
            uses the platform's locale encoding, as before; pass e.g.
            ``"utf-8"`` for results that do not depend on the platform.

    Raises:
        ValueError: If ``batch_size`` is given and is not positive.
        OSError: If the file cannot be opened (checked eagerly in __init__).
    """

    def __init__(self, text_path, batch_size=None, encoding=None):
        super().__init__()
        if batch_size is not None and batch_size <= 0:
            raise ValueError("batch_size must be a positive integer or None")
        self.text_path = text_path
        self.batch_size = batch_size
        self.encoding = encoding
        # Fail fast on a missing/unreadable file, as the original did by
        # opening it here, but close the probe handle immediately.
        with open(text_path, "r", encoding=encoding):
            pass
        # Batch generator for the current pass; created lazily.
        self._batches = None

    def _read_lines(self):
        """Return all lines of the file, closing the handle afterwards."""
        with open(self.text_path, "r", encoding=self.encoding) as f:
            return f.readlines()

    def _generate_batches(self):
        """Yield consecutive slices of ``batch_size`` lines; the last may be shorter."""
        lines = self._read_lines()
        if not lines:
            # Guard: range() below would get step 0 for an empty file.
            return
        step = self.batch_size if self.batch_size is not None else len(lines)
        for start in range(0, len(lines), step):
            yield lines[start:start + step]

    def __len__(self):
        """Return the number of batches a full pass over the file produces."""
        n_lines = len(self._read_lines())
        if n_lines == 0:
            return 0
        if self.batch_size is None:
            return 1
        # Ceiling division: a partial final batch still counts as a batch.
        return (n_lines + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        # Start a fresh pass so the reader can be looped over repeatedly.
        self._batches = self._generate_batches()
        return self

    def __next__(self):
        # Support bare next(reader) calls without a preceding iter().
        if self._batches is None:
            self._batches = self._generate_batches()
        return next(self._batches)
阅读全文