如果数据为一个二维数组,第一行为数据,第二行为时间,该如何去修改上述代码
时间: 2024-03-02 10:52:00 浏览: 143
我的代码数据结构求修改
如果数据为一个二维数组,第一行为数据,第二行为时间,你需要对代码进行如下修改:
```python
import torch
import torch.nn.functional as F
from transformers import GPT2Tokenizer, GPT2LMHeadModel
class ChitChat:
def __init__(self):
self.tokenizer = GPT2Tokenizer.from_pretrained('microsoft/DialoGPT-medium')
self.model = GPT2LMHeadModel.from_pretrained('microsoft/DialoGPT-medium')
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model.to(self.device)
self.model.eval()
self.chat_history = []
def get_response(self, user_input):
input_ids = self.tokenizer.encode(user_input + self.tokenizer.eos_token, return_tensors='pt').to(self.device)
chat_history_ids = self.model.generate(
input_ids=input_ids,
max_length=1000,
pad_token_id=self.tokenizer.eos_token_id,
no_repeat_ngram_size=3,
do_sample=True,
top_k=10,
top_p=0.7,
temperature=0.8,
)
response = self.tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
self.chat_history.append((user_input, response))
return response
```
在修改后的代码中,`get_response` 函数的输入参数为用户的输入,第一行数据和第二行时间都不需要作为参数传递。如果你需要在 `ChitChat` 类中保存时间信息,你可以在 `chat_history` 列表中保存每次用户输入的时间戳,例如:
```python
import time
class ChitChat:
def __init__(self):
self.tokenizer = GPT2Tokenizer.from_pretrained('microsoft/DialoGPT-medium')
self.model = GPT2LMHeadModel.from_pretrained('microsoft/DialoGPT-medium')
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model.to(self.device)
self.model.eval()
self.chat_history = []
def get_response(self, user_input):
input_ids = self.tokenizer.encode(user_input + self.tokenizer.eos_token, return_tensors='pt').to(self.device)
chat_history_ids = self.model.generate(
input_ids=input_ids,
max_length=1000,
pad_token_id=self.tokenizer.eos_token_id,
no_repeat_ngram_size=3,
do_sample=True,
top_k=10,
top_p=0.7,
temperature=0.8,
)
response = self.tokenizer.decode(chat_history_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
self.chat_history.append((user_input, time.time(), response))
return response
```
在上述代码中,`chat_history` 列表中每个元素包含三个值:用户输入、时间戳和机器人回复。
阅读全文