训练完怎么求相似度和词向量
时间: 2024-05-07 18:21:45 浏览: 11
要求相似度和词向量,需要使用预训练好的模型。对于ChitGPT模型,可以使用Hugging Face提供的transformers库来加载已经训练好的模型。然后,可以使用模型的`get_word_embedding_matrix()`方法来获取词向量矩阵,或者使用`cosine_similarity()`方法来计算两个句子之间的余弦相似度。
下面是一些示例代码:
```python
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity
# 加载tokenizer和模型
tokenizer = AutoTokenizer.from_pretrained('microsoft/DialoGPT-small')
model = AutoModel.from_pretrained('microsoft/DialoGPT-small')
# 获取词向量矩阵
word_embeddings = model.get_input_embeddings().weight.detach().numpy()
# 计算两个句子之间的余弦相似度
input_text1 = "Hello, how are you?"
input_text2 = "Hi, how are you doing?"
input_tokens1 = tokenizer(input_text1, return_tensors='pt')['input_ids']
input_tokens2 = tokenizer(input_text2, return_tensors='pt')['input_ids']
output1 = model(input_tokens1)[0][:, -1, :]
output2 = model(input_tokens2)[0][:, -1, :]
similarity = cosine_similarity(output1.detach().numpy(), output2.detach().numpy())[0][0]
print(similarity)
```
注意,这里使用的是Microsoft的DialoGPT-small模型,而不是CSDN开发的ChitGPT模型。如果你想使用ChitGPT模型,只需要将模型名称和tokenizer名称替换为对应的名称即可。
相关推荐
![pdf](https://img-home.csdnimg.cn/images/20210720083512.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![rar](https://img-home.csdnimg.cn/images/20210720083606.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)