Combine the code above into a single, complete program
Posted: 2024-05-10 08:15:41
Here is the complete, combined code:
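```python
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

class ChitChat:
    def __init__(self):
        # DialoGPT is a GPT-2 model, so the GPT-2 tokenizer/model classes can load it
        self.tokenizer = GPT2Tokenizer.from_pretrained('microsoft/DialoGPT-medium')
        self.model = GPT2LMHeadModel.from_pretrained('microsoft/DialoGPT-medium')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()
        # Token ids of the whole conversation so far (None before the first turn)
        self.chat_history_ids = None

    def generate(self, text):
        # Encode the user turn terminated by EOS and move it to the model's device
        new_input_ids = self.tokenizer.encode(text + self.tokenizer.eos_token, return_tensors='pt').to(self.device)
        if self.chat_history_ids is not None:
            bot_input_ids = torch.cat([self.chat_history_ids, new_input_ids], dim=-1)
        else:
            bot_input_ids = new_input_ids
        self.chat_history_ids = self.model.generate(
            bot_input_ids,
            attention_mask=torch.ones_like(bot_input_ids),  # no padding: attend to every token
            max_length=1000,
            pad_token_id=self.tokenizer.eos_token_id,
            do_sample=True,  # temperature is ignored unless sampling is enabled
            temperature=0.7,
            no_repeat_ngram_size=3,
        )
        # Decode only the newly generated tokens, i.e. the bot's reply
        response = self.tokenizer.decode(self.chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
        return response
```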
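The `generate` method keeps the whole conversation in `chat_history_ids`: each user turn (ending in the EOS token) is appended to the history, the model continues the sequence, and the reply is everything after the input length. Because the history is never trimmed, it eventually runs into GPT-2's 1024-position context. The sketch below mimics that bookkeeping with plain Python lists of token ids instead of tensors. `EOS_ID = 50256` is GPT-2's end-of-text id; the helper names `build_input`/`split_reply`, the fake reply tokens, and the trim-to-most-recent policy are hypothetical illustrations, not something the original code does.

```python
from typing import List

EOS_ID = 50256       # GPT-2 / DialoGPT end-of-text token id
MAX_CONTEXT = 1024   # GPT-2 position limit (n_positions)


def build_input(history: List[int], new_turn: List[int],
                max_context: int = MAX_CONTEXT, reserve: int = 0) -> List[int]:
    """Append a user turn plus EOS to the history (the torch.cat step in ChitChat.generate).

    Hypothetical extension: if the result is longer than max_context - reserve,
    keep only the most recent tokens, leaving `reserve` positions for the reply.
    """
    budget = max_context - reserve
    if budget <= 0:
        raise ValueError("reserve must be smaller than max_context")
    bot_input = history + new_turn + [EOS_ID]
    if len(bot_input) > budget:
        bot_input = bot_input[-budget:]
    return bot_input


def split_reply(full_output: List[int], input_len: int) -> List[int]:
    """Keep only the generated tokens (chat_history_ids[:, input_len:][0] in the original)."""
    return full_output[input_len:]


# One simulated turn: user tokens [10, 11]; the "model" appends reply tokens [20, 21, EOS]
history: List[int] = []
bot_input = build_input(history, [10, 11])
model_output = bot_input + [20, 21, EOS_ID]   # stand-in for model.generate(...)
reply = split_reply(model_output, len(bot_input))
history = model_output                         # the next turn continues from the full output
print(bot_input, reply)   # [10, 11, 50256] [20, 21, 50256]
```

Cutting at an arbitrary token can split a turn in half; dropping whole turns (everything up to an EOS) from the front of the history is a cleaner alternative.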
Usage:
```python
chatbot = ChitChat()
while True:
    user_input = input("User: ")
    response = chatbot.generate(user_input)
    print("Bot: " + response)
```
Make sure the `transformers` and `torch` libraries are installed (for example, `pip install transformers torch`).
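The usage loop above never exits on its own; the only way out is Ctrl+C. Below is a minimal sketch of one way to end it cleanly. It assumes hypothetical exit words `quit`/`exit` and a hypothetical `chat_loop` helper whose `read`/`write` parameters default to `input`/`print`, so the loop can be tried with a stand-in bot without downloading DialoGPT.

```python
from typing import Callable, Protocol

EXIT_WORDS = {"quit", "exit"}  # hypothetical exit commands


class Bot(Protocol):
    """Anything with a generate(text) -> str method, such as ChitChat."""
    def generate(self, text: str) -> str: ...


def chat_loop(bot: Bot, read: Callable[[str], str] = input,
              write: Callable[[str], None] = print) -> int:
    """Run the User/Bot loop until an exit word or end of input; return the number of replies."""
    turns = 0
    while True:
        try:
            user_input = read("User: ").strip()
        except EOFError:  # Ctrl+D or end of piped input
            break
        if not user_input:
            continue  # skip blank lines instead of sending a bare EOS to the model
        if user_input.lower() in EXIT_WORDS:
            break
        write("Bot: " + bot.generate(user_input))
        turns += 1
    return turns
```

With the model available, `chat_loop(ChitChat())` can replace the `while True` loop; it behaves the same except that blank lines are skipped and `quit`, `exit`, or end of input stops it.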