torch.gather
时间: 2023-11-08 16:01:18 浏览: 150
torch.gather是一个PyTorch中的函数,它用于按照给定的索引列表从输入张量中选择指定的元素。在给定的维度上,它根据索引列表index从输入张量input中选取元素,并返回一个新的张量。
具体来说,对于一个3D张量,输出的元素由如下公式指定:
out[i][j][k] = input[index[i][j][k]][j][k],当dim=0;
out[i][j][k] = input[i][index[i][j][k]][k],当dim=1;
out[i][j][k] = input[i][j][index[i][j][k]],当dim=2。
例如,对于以下代码:
b = torch.Tensor([[1,2,3],[4,5,6]])
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print (torch.gather(b, dim=1, index=index_1))
print (torch.gather(b, dim=0, index=index_2))
输出结果为:
tensor([[1, 2],
[6, 4]])
tensor([[1, 5, 5],
[1, 2, 3]])
这意味着在第一个例子中,我们选择了索引为(0,0)的元素1,索引为(1,1)的元素5,并将它们放置在新的张量中。在第二个例子中,我们选择了索引为(0,0)的元素1,并将它放置在新的张量中的第一行,选择了索引为(1,0)的元素4,并将它放置在新的张量中的第二行,然后将索引为(1,1)和(1,2)的元素都选择为5,并将它们放置在新的张量中的第三行。
阅读全文