def build_dataset(config, ues_word): if ues_word: tokenizer = lambda x: x.split(' ') # 以空格隔开,word-level else: tokenizer = lambda x: [y for y in x] # char-level if os.path.exists(config.vocab_path): vocab = pkl.load(open(config.vocab_path, 'rb')) else: vocab = build_vocab(config.train_path, tokenizer=tokenizer, max_size=MAX_VOCAB_SIZE, min_freq=1) pkl.dump(vocab, open(config.vocab_path, 'wb')) print(f"Vocab size: {len(vocab)}") def load_dataset(path, pad_size=32): contents = [] with open(path, 'r', encoding='UTF-8') as f: for line in tqdm(f): lin = line.strip() if not lin: continue content, label = lin.split('\t') words_line = [] token = tokenizer(content) seq_len = len(token) if pad_size: if len(token) < pad_size: token.extend([PAD] * (pad_size - len(token))) else: token = token[:pad_size] seq_len = pad_size # word to id for word in token: words_line.append(vocab.get(word, vocab.get(UNK))) contents.append((words_line, int(label), seq_len)) return contents # [([...], 0), ([...], 1), ...] train = load_dataset(config.train_path, config.pad_size) dev = load_dataset(config.dev_path, config.pad_size) test = load_dataset(config.test_path, config.pad_size) return vocab, train, dev, test
时间: 2024-01-12 20:04:45 浏览: 94
hand_dataset.tar.gz
这段代码定义了一个函数用于加载数据集。它的输入参数包括一个配置对象和一个布尔值,表示是否使用词级别的分词器。如果使用单词级别的分词器,就将句子按照空格分割成单词;否则,将句子分割成单个字符。如果已经存在词汇表文件,就直接加载该文件;否则,就调用之前定义的 build_vocab 函数构建词汇表,并将其保存到文件中。然后,函数分别加载训练、验证和测试数据集,并将每个样本表示成一个三元组,其中第一个元素是由单词索引构成的列表,第二个元素是标签,第三个元素是该样本的序列长度。最后,该函数返回词汇表和三个数据集。
阅读全文