将上述代码写成一个通用得工具类
时间: 2024-03-06 07:49:43 浏览: 28
好的,可以先定义一个 Python 类,可以接收输入的文本,然后使用预训练的 GPT 模型生成 AI 回复,最后将回复输出。以下是一个简单的实现:
```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
class GPTChatbot:
def __init__(self):
self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
self.model = GPT2LMHeadModel.from_pretrained('gpt2')
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model.to(self.device)
self.model.eval()
def generate_response(self, input_text):
input_ids = self.tokenizer.encode(input_text, return_tensors='pt').to(self.device)
sample_output = self.model.generate(
input_ids,
do_sample=True,
max_length=50,
top_k=50,
top_p=0.95,
num_return_sequences=1
)
response = self.tokenizer.decode(sample_output[0], skip_special_tokens=True)
return response
```
这个类中,我们首先使用 `GPT2Tokenizer` 和 `GPT2LMHeadModel` 类来加载 GPT 模型和分词器。然后,我们可以使用 `generate_response` 方法接收用户输入,并使用 GPT 模型生成 AI 回复。在 `generate_response` 方法中,我们先将输入文本编码为 token ID,并使用 `model.generate()` 方法生成 AI 回复。回复以 token ID 的形式返回,我们可以使用分词器将其转换为文本,并返回 AI 回复。
你可以使用以下代码测试这个工具类:
```python
chatbot = GPTChatbot()
while True:
user_input = input('User: ')
response = chatbot.generate_response(user_input)
print('AI: ' + response)
```
这个代码将启动一个命令行聊天程序,你可以不断地输入问题,然后 AI 将回答你的问题。