import torch from transformer import CutTaskModel,PositionalEncoding model=CutTaskModel() model= torch.load('./teach_transformer/teach_transformer.pt') src = torch.LongTensor([[0, 4, 3, 4, 6, 8, 9, 9, 8, 1, 2, 2]]) # tgt从<bos>开始,看看能不能重新输出src中的值 tgt = torch.LongTensor([[0]]) # 一个一个词预测,直到预测为<eos>,或者达到句子最大长度 for i in range(20): # 进行transformer计算 out = model(src, tgt) # 预测结果,因为只需要看最后一个词,所以取`out[:, -1]` predict = model.predictor(out[:, -1]) # 找出最大值的index y = torch.argmax(predict, dim=1) # 和之前的预测结果拼接到一起 tgt = torch.concat([tgt, y.unsqueeze(0)], dim=1) # 如果为<eos>,说明预测结束,跳出循环 if y == 1: break print(tgt) 这是源代码
时间: 2024-04-08 16:31:11 浏览: 193
transformer_torch.zip
5星 · 资源好评率100%
在你的代码中,出现了一个错误。错误发生在 `tgt = torch.concat([tgt, y.unsqueeze(0)], dim=1)` 这一行。正确的方法是使用 `torch.cat` 而不是 `torch.concat`。
请将这一行代码修改为:`tgt = torch.cat([tgt, y.unsqueeze(0)], dim=1)`。
修改后的代码如下所示:
```python
import torch
from transformer import CutTaskModel, PositionalEncoding
model = CutTaskModel()
model = torch.load('./teach_transformer/teach_transformer.pt')
src = torch.LongTensor([[0, 4, 3, 4, 6, 8, 9, 9, 8, 1, 2, 2]])
tgt = torch.LongTensor([[0]])
for i in range(20):
out = model(src, tgt)
predict = model.predictor(out[:, -1])
y = torch.argmax(predict, dim=1)
tgt = torch.cat([tgt, y.unsqueeze(0)], dim=1)
if y == 1:
break
print(tgt)
```
这样修改后,应该能够解决你遇到的问题。
阅读全文