_Q = torch.gather(_Qs, 1, _action) #摇摆棍 RuntimeError: gather(): Expected dtype int64 for index
时间: 2024-02-25 16:59:06 浏览: 118
torch_scatter-2.1.2-cp310-cp310-macosx_11_0_x86_64.whl.zip
这个错误通常是因为 `gather()` 函数的第二个参数 `_action` 的数据类型不是 `torch.int64`。`gather()` 函数需要整数类型的索引来从 `_Qs` 张量中收集数据。你可以尝试将 `_action` 转换为 `torch.int64` 类型,例如:
```python
_Q = torch.gather(_Qs, 1, _action.long())
```
如果 `_action` 的数据类型已经是整数类型,你也可以尝试将其转换为 `torch.int64`,例如:
```python
_Q = torch.gather(_Qs, 1, _action.to(torch.int64))
```
如果你需要进一步帮助,请提供更多的上下文和代码。
阅读全文