def _pred(sentence, temperature=1): if len(sentence) < max_len: print('in def _pred,length error ') return sentence = sentence[-max_len:] x_pred = np.zeros((1, max_len, len(words))) for t, char in enumerate(sentence): x_pred[0, t, word2numF(char)] = 1. preds = model.predict(x_pred, verbose=0)[0] preds = np.asarray(preds).astype('float64') exp_preds = np.power(preds, temperature) # 计算所有备选输出文字概率`preds`的`temperature`次方 preds = exp_preds / np.sum(exp_preds) # 重新统计概率分布 pro = np.random.choice(range(len(preds)), 1, p=preds) # 根据新概率随机选择候选文字 next_index = int(pro.squeeze()) next_char = num2word[next_index] return next_char
时间: 2023-11-28 14:06:40 浏览: 59
这段代码是一个用于生成文本的函数 `_pred`,它接受一个字符串 `sentence` 和一个温度参数 `temperature` 作为输入,返回一个字符作为输出。这个函数的实现基于一个训练好的模型,并利用该模型对给定的 `sentence` 进行预测,得到下一个字符的概率分布。然后,使用温度参数 `temperature` 对概率分布进行调节,加入一定的随机性,最后根据新的概率分布随机选择一个候选字符作为输出,并将其返回。
相关问题
import ast from dataclasses import dataclass from typing import List import pandas as pd import json ["text", "六十一岁还能办什么保险"] @dataclass class FAQ: title: str sim_questions: List[str] answer: str faq_id: int ori_data = pd.read_csv('baoxianzhidao_filter.csv') data = [] exist_titles = set() for index, row in enumerate(ori_data.iterrows()): row_dict = row[1] title = row_dict['title'] if title not in exist_titles: data.append(FAQ(title=title, answer=row_dict['reply'], sim_questions=[title], faq_id=index)) exist_titles.add(title) from modelscope.pipelines import pipeline from modelscope.utils.constant import Tasks pipeline_ins = pipeline(Tasks.faq_question_answering, 'damo/nlp_mgimn_faq-question-answering_chinese-base') bsz = 32 all_sentence_vecs = [] batch = [] sentence_list = [faq.title for faq in data] for i,sent in enumerate(sentence_list): batch.append(sent) if len(batch) == bsz or (i == len(sentence_list)-1 and len(batch)>0): # if i == len(sentence_list)-1 and len(batch)>0: sentence_vecs = pipeline_ins.get_sentence_embedding(batch) all_sentence_vecs.extend(sentence_vecs) batch.clear() import faiss import numpy as np hidden_size = pipeline_ins.model.network.bert.config.hidden_size # hidden_size = pipeline_ins.model.bert.config.hidden_size index = faiss.IndexFlatIP(hidden_size) vecs = np.asarray(all_sentence_vecs, dtype='float32') index.add(vecs) from modelscope.outputs import OutputKeys def ask_faq(input, history=[]): # step1: get sentence vector of query query_vec = pipeline_ins.get_sentence_embedding([input])[0] query_vec = np.asarray(query_vec, dtype='float32').reshape([1, -1]) # step2: faq dense retrieval _, indices = index.search(query_vec, k=30) # step3: build support set support_set = [] for i in indices.tolist()[0]: faq = data[i] support_set.append({"text": faq.title, "label": faq.faq_id, "index": i}) # step4: faq ranking rst = pipeline_ins(input={"query_set": input, "support_set": support_set}) rst = rst[OutputKeys.OUTPUT][0][0] pred_label = rst['label'] pred_score = rst['score'] # get answer by faq_id pred_answer = "" pred_title = "" for faq in data: if faq.faq_id == pred_label: pred_answer = faq.answer pred_title = faq.title break history.append((f'{pred_answer}|(pred_title:{pred_title},pred_score:{pred_score:.3f})')) return history优化这段代码
这段代码是一个Python脚本,用于读取CSV文件中的保险相关问题和答案,构建一个FAQ对象(包含问题、答案、相似问题和FAQ ID),并使用modelscope库中的pipeline进行常见问题解答。其中用到了ast、dataclass、List、pandas和json等Python模块。具体实现过程请参考上述代码。
Traceback (most recent call last): File "/home/chenxingyue/codes/caopengfei/CMeKG_tools/test5.py", line 9, in <module> my_pred.predict_sentence("".join(sentence.split())) File "/home/chenxingyue/codes/caopengfei/CMeKG_tools/medical_cws.py", line 105, in predict_sentence self.model.load_state_dict(torch.load(self.NEWPATH,map_location=self.device)) File "/home/chenxingyue/anaconda3/envs/py39/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1667, in load_state_dict raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for BERT_LSTM_CRF: Missing key(s) in state_dict: "word_embeds.embeddings.position_ids".
这个错误是因为在加载模型的state_dict时,缺少了键"word_embeds.embeddings.position_ids"。这可能是因为你的模型和保存的state_dict之间存在不匹配的情况。你可以尝试以下解决方法:
1. 确保你的模型和保存的state_dict是兼容的,即它们具有相同的模型结构和参数数量。
2. 检查模型的定义,确保没有更改模型的结构或参数名称。
3. 如果你有一个以前版本的state_dict,尝试使用该版本进行加载,或者尝试使用其他可用的预训练模型进行加载。
4. 如果你有训练好的模型权重文件,可以尝试重新训练模型并保存新的state_dict。
希望这些解决方法能够帮助你解决问题!如果还有其他问题,请随时提问。
阅读全文