input_ids = [int(vocab[i]) if i in vocab else int(vocab['[UNK]']) for i in tokens_f]代码报错KeyError: '[UNK]'怎么解决
时间: 2023-10-09 20:17:29 浏览: 46
这个错误通常是因为词汇表(vocab)中没有'[UNK]'这个特殊符号导致的。解决方法是先检查词汇表中是否有'[UNK]'这个特殊符号,如果没有,可以手动添加。如果词汇表中已经有了'[UNK]'这个特殊符号,那么可能是tokens_f中的某个token没有被正确地分割或处理,导致它无法匹配到词汇表中的任何一个token,从而触发了这个错误。此时可以检查tokens_f中的每个token是否正确,或者尝试使用其他的分词方法处理文本数据。
相关问题
for _ in range(generate_max_len): outputs = model(input_ids=curr_input_tensor) next_token_logits = outputs[0][-1, :] #size:[vocab size] # 对于已生成的结果generated中的每个token添加一个重复惩罚项,降低其生成概率 for id_ in set(generated): next_token_logits[id_] /= repetition_penalty # 对于[UNK]的概率设为无穷小,也就是说模型的预测结果不可能是[UNK]这个token next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf') filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=topk, top_p=topp) # torch.multinomial表示从候选集合中无放回地进行抽取num_samples个元素,权重越高,抽到的几率越高,返回元素的下标 next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) if next_token.item() == tokenizer.sep_token_id: # 遇到[SEP]则表明response生成结束 break generated.append(next_token.item()) curr_input_tensor = torch.cat((curr_input_tensor, next_token), dim=0) text = tokenizer.convert_ids_to_tokens(generated) print("summary:" + "".join(text))这段什么意思
这段代码是用于基于预训练的语言模型生成文本的过程。它使用了top-k和top-p采样的技术,以及一个重复惩罚项,来生成一个在给定前缀下的文本摘要。其中,top-k采样是指从预测的概率分布中选择前k个最高概率的token,top-p采样是指从预测的概率分布中选择累计概率超过一个阈值p的所有token。这样可以增加生成文本的多样性。重复惩罚项是为了防止模型重复生成相同的token。具体来说,对于已经生成的token,它们在下一次生成时的概率会进行一定的惩罚,以降低它们被重复生成的概率。另外,对于[UNK]这个token,模型的预测结果不可能是它,因此它的概率被设为无穷小。最终,生成的文本在遇到[SEP]这个token时结束,输出生成的文本摘要。
请将下列shell语句转换为python语句:if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then # Make train dict echo "Make a dictionary" mkdir -p $(dirname $dict) echo "<blank> 0" > ${dict} # 0 will be used for "blank" in CTC echo "<unk> 1" >> ${dict} # <unk> must be 1 if [ "${token_type}" = bpe ]; then # we borrowed these code and scripts which are related bpe from ESPnet. cut -f 2- -d" " data/${train_set}/text | sort > data/lang_char/input.txt tools/spm_train --input=data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000 tools/spm_encode --model=${bpemodel}.model --output_format=piece < data/lang_char/input.txt | tr ' ' '\n' | sort | uniq | awk '{print $0 " " NR+1}' >> ${dict} elif [ "${token_type}" = char ]; then tools/text2token.py -s 1 -n 1 data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \ | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0 " " NR+1}' >> ${dict} fi num_token=$(cat $dict | wc -l) echo "<sos/eos> $num_token" >> $dict # <eos> fi
以下是将上述shell语句转换为Python语句的代码:
```python
if stage <= 2 and stop_stage >= 2:
# Make train dict
print("Make a dictionary")
os.makedirs(os.path.dirname(dict), exist_ok=True)
with open(dict, "w") as f:
f.write("<blank> 0\n") # 0 will be used for "blank" in CTC
f.write("<unk> 1\n") # <unk> must be 1
if token_type == "bpe":
# we borrowed these code and scripts which are related bpe from ESPnet.
with open(f"data/{train_set}/text") as text_file:
texts = [line.strip().split()[1:] for line in text_file]
with open("data/lang_char/input.txt", "w") as input_file:
for text in texts:
input_file.write(" ".join(text) + "\n")
subprocess.run(f"tools/spm_train --input=data/lang_char/input.txt --vocab_size={nbpe} --model_type={bpemode} --model_prefix={bpemodel} --input_sentence_size=100000000", shell=True)
with open(f"{bpemodel}.vocab") as vocab_file:
vocab = [line.strip().split()[0] for line in vocab_file]
with open(dict, "a") as dict_file:
for i, vocab_item in enumerate(vocab):
dict_file.write(f"{vocab_item} {i+2}\n")
elif token_type == "char":
with open(f"data/{train_set}/text") as text_file:
texts = [line.strip().split()[1] for line in text_file]
texts = " ".join(texts)
tokens = subprocess.check_output(f"echo '{texts}' | tools/spm_encode --model={bpemodel}.model --output_format=piece", shell=True)
tokens = tokens.decode().split("\n")
tokens = [token.strip() for token in tokens if token.strip() != ""]
tokens = sorted(set(tokens))
with open(dict, "a") as dict_file:
for i, token in enumerate(tokens):
dict_file.write(f"{token} {i+2}\n")
num_token = sum(1 for line in open(dict))
with open(dict, "a") as dict_file:
dict_file.write(f"<sos/eos> {num_token}\n") # <eos>
```
需要注意的是,在Python中也需要使用相应的模块和工具来实现shell语句中的一些操作,例如在此代码中使用了`os`、`subprocess`模块和`tools/spm_train`、`tools/spm_encode`脚本来进行文件操作和调用外部命令。另外,由于Python中没有直接对应的`$`符号,需要使用`f-string`或者`str.format()`方法来进行字符串格式化。
阅读全文