y = torch.tensor([0, 2]) y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]]) y_hat[[0, 1], y]中y_hat是如何索引的
时间: 2024-01-02 10:05:19 浏览: 117
one hot编码:`torch.Tensor.scatter_()`函数用法详解
y_hat[[0, 1], y]是利用了Tensor的高级索引(advanced indexing)功能,在y_hat矩阵的第一维中选择索引为[0,1]的两个元素,而在第二维中,选择与y中对应位置的数值相等的元素。因此,y_hat[[0, 1], y]实际上是选择了y_hat矩阵中第1行的第0列元素和第2行的第2列元素,即[0.1, 0.5]。
阅读全文