return scatter_max(src, index, dim, out, dim_size)[0] RuntimeError: expected scalar type Long but found Float
时间: 2023-05-18 18:06:20 浏览: 76
这个错误提示意味着你的代码期望输入的数据类型是 Long,但实际输入的是 Float 类型。你需要检查你的代码,找到这个问题的根源并进行修复。可能需要将输入数据类型转换为 Long 类型,或者修改代码中的其他部分以适应 Float 类型的输入。
相关问题
Traceback (most recent call last): File "training.py", line 96, in <module> train(model, device, train_loader, optimizer, epoch+1) File "training.py", line 28, in train output = model(data) File "/home/qcui/anaconda3/envs/geometric/lib/python3.8/site-packages/torch/nn/modules/module.py", line 532, in __call__ result = self.forward(*input, **kwargs) File "/data/qcui/DTA/GraphDTA/ecfp_maccs_graph/CPI_prediction/models/gcn.py", line 64, in forward x = gmp(x, batch) # global max pooling File "/home/qcui/anaconda3/envs/geometric/lib/python3.8/site-packages/torch_geometric/nn/glob/glob.py", line 76, in global_max_pool return scatter(x, batch, dim=0, dim_size=size, reduce='max') File "/home/qcui/anaconda3/envs/geometric/lib/python3.8/site-packages/torch_scatter/scatter.py", line 155, in scatter return scatter_max(src, index, dim, out, dim_size)[0] RuntimeError: expected scalar type Long but found Float The above operation failed in interpreter. Traceback (most recent call last): File "/home/qcui/anaconda3/envs/geometric/lib/python3.8/site-packages/torch_scatter/scatter.py", line 69 dim_size: Optional[int] = None ) -> Tuple[torch.Tensor, torch.Tensor]: return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
这是一个Python程序的错误信息,其中包含了一些文件名和行号。具体来说,这个错误信息显示了在训练模型时出现了错误,错误发生在第96行,train函数中的第28行调用了模型,但是模型的输入出现了问题,导致程序无法继续执行。具体的错误信息是在gcn.py文件的第64行,模型在进行计算时调用了gmp函数,但是输入的数据和batch信息有问题。
RuntimeError: scatter(): Expected dtype int64 for index
This error occurs when the index used for the scatter operation is not of type int64.
Possible solutions:
1. Convert the index to int64 using the `astype()` method.
```
index = index.astype('int64')
```
2. Ensure that the index is already of type int64 before performing the scatter operation.
```
index = torch.tensor([1, 2, 3], dtype=torch.int64)
```
3. Check that the index values are within the range of the tensor being scattered. If the index is out of range, the scatter operation will fail.