(base) D:\BlackOp\Source\QuantitativeTrade\Source> (base) D:\BlackOp\Source\QuantitativeTrade\Source>python test.py Traceback (most recent call last): File "D:\BlackOp\Source\QuantitativeTrade\Source\test.py", line 90, in <module> model.TrainTransformerModule(num_epochs) File "D:\BlackOp\Source\QuantitativeTrade\Source\test.py", line 40, in TrainTransformerModule outputs = self.forward(batch_inputs) File "D:\BlackOp\Source\QuantitativeTrade\Source\test.py", line 19, in forward x = self.embedding(x) File "D:\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl return forward_call(*input, **kwargs) File "D:\Anaconda3\lib\site-packages\torch\nn\modules\sparse.py", line 160, in forward return F.embedding( File "D:\Anaconda3\lib\site-packages\torch\nn\functional.py", line 2210, in embedding return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse) RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)
时间: 2023-07-21 07:03:05 浏览: 187
这个错误是由于在调用 `torch.embedding()` 函数时,输入的索引张量类型不正确导致的。函数要求索引张量的数据类型是 `torch.LongTensor` 或者 `torch.IntTensor`,但是你传入的索引张量的数据类型是 `torch.cuda.FloatTensor`。
为了解决这个问题,你需要将索引张量的数据类型转换为正确的类型。可以使用 `torch.tensor()` 函数将数据类型转换为 `torch.LongTensor`,如下所示:
```python
x = torch.tensor(x, dtype=torch.long)
```
在你的代码中,可能需要对 `batch_inputs` 进行相应的转换,具体取决于你在哪里使用了该张量作为索引。
确保在调用 `torch.embedding()` 函数之前,将索引张量的数据类型正确转换为 `torch.LongTensor` 或者 `torch.IntTensor`,这样就可以解决这个错误。
阅读全文