y = torch.tensor([0, 2]) y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]]) y_hat[[0, 1], y]怎么理解
时间: 2024-01-05 10:03:33 浏览: 78
Pytorch学习之torch用法—-比较操作(Comparison Ops)
这是一个关于PyTorch的张量索引问题。y是一个1维的张量,里面包含两个元素[0, 2],表示要取y_hat矩阵中第0行和第2行。而y_hat[[0, 1], y]的结果是取y_hat矩阵中第0行和第2行,以及每行中索引为0和索引为2的元素,也就是[0.1, 0.6]和[0.5, 0.5],返回一个1维张量。
阅读全文