return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing) RuntimeError: Expected floating point type for target with class probabilities, got Long
时间: 2024-01-22 17:19:57 浏览: 329
seq2seq_tutorial_torch.rar_Torch tutorial_seq2seq_torch_英文到数字的转化
这个错误通常是因为你的 `target` 张量的数据类型是 `Long` (整数类型),而交叉熵损失函数需要接受一个浮点数类型的 `target` 张量,该张量应该是类别概率。为了解决这个问题,你可以将 `target` 张量转换为浮点数类型,可以使用 `torch.float()` 方法将 `target` 张量转换为浮点数类型。例如:
```python
target = target.float()
```
你也可以在加载数据时将 `target` 张量转换为浮点数类型,这样你就不需要在每次训练时转换了。
阅读全文