结合改写后的工厂类,给出全部的源代码
时间: 2023-11-11 19:06:42 浏览: 27
好的,下面是改写后的工厂类的完整源代码:
```python
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
class ChitGPT:
def __init__(self, model_name_or_path):
self.tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
self.model = GPT2LMHeadModel.from_pretrained(model_name_or_path)
self.model.eval()
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)
def generate_response(self, input_text, max_length=100, top_p=0.92, temperature=0.85):
input_ids = self.tokenizer.encode(input_text, return_tensors='pt')
input_ids = input_ids.to(self.device)
sample_outputs = self.model.generate(
input_ids,
do_sample=True,
max_length=max_length,
top_p=top_p,
temperature=temperature,
num_return_sequences=1,
)
response = self.tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
return response
class ChitGPTFactory:
def __init__(self, model_name_or_path):
self.model_name_or_path = model_name_or_path
def create_chitgpt(self):
return ChitGPT(self.model_name_or_path)
```
其中,`ChitGPT` 类用于实现与用户进行对话的功能,而 `ChitGPTFactory` 类则用于创建 `ChitGPT` 实例。通过工厂类,我们可以在需要的时候方便地创建 `ChitGPT` 对象,而无需每次手动创建并初始化对象。