RuntimeError: "nll_loss_forward_reduce_cuda_kernel_2d_index" not implemented for 'int'
时间: 2023-12-04 21:04:57 浏览: 215
JAVA_API_1.7中文.zip_JDK API_java1.7_jdk 1.7_jdk 1.7 api_jdk-7
这个错误通常出现在使用 PyTorch 进行训练时,损失函数的输入数据类型不正确。
`nll_loss_forward_reduce_cuda_kernel_2d_index` 是 PyTorch 中计算负对数似然损失的函数,它要求输入的标签数据类型为 LongTensor 类型,而你的输入数据类型是 int 类型,所以会报错。
解决方法是将标签数据类型转换为 LongTensor 类型,可以使用 `label = label.long()` 将标签转换成 LongTensor 类型。
阅读全文