torch.gather
Date: 2023-09-03 11:15:36
torch.gather collects values from an input tensor along a given dimension, using the positions specified in an index tensor. It does not broadcast: `index` must have the same number of dimensions as `input`, and for every dimension `d != dim`, `index.size(d) <= input.size(d)`.
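Written out for a 3-D tensor, the element-wise rule (as stated in the PyTorch documentation) replaces the coordinate along `dim` with the value from `index` and keeps every other coordinate unchanged:

$$
\text{out}[i][j][k] =
\begin{cases}
\text{input}\big[\text{index}[i][j][k]\big][j][k] & \text{if } \text{dim} = 0\\
\text{input}[i]\big[\text{index}[i][j][k]\big][k] & \text{if } \text{dim} = 1\\
\text{input}[i][j]\big[\text{index}[i][j][k]\big] & \text{if } \text{dim} = 2
\end{cases}
$$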
Syntax: `torch.gather(input, dim, index, *, sparse_grad=False, out=None)`
Parameters:
- `input` (Tensor): The input tensor to gather values from.
- `dim` (int): The dimension along which to index into the `input` tensor.
- `index` (LongTensor): The indices of the elements to gather. It must have the same number of dimensions as `input`, with `index.size(d) <= input.size(d)` for every dimension `d != dim`; along `dim` itself it may have any size. Every value must lie in the range `[0, input.size(dim))`. `index` is not broadcast against `input`.
- `sparse_grad` (bool, optional): If True, the gradient with respect to `input` is a sparse tensor. Defaults to False.
- `out` (Tensor, optional): The output tensor. If given, the result is written into it; otherwise a new tensor is returned.
Returns: A tensor with the same shape as `index`. Each element is read from `input` at the position whose coordinate along `dim` is the corresponding entry of `index`; all other coordinates are the element's own coordinates.
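To make the indexing rule concrete, here is a minimal pure-Python sketch of the 2-D case. `gather_2d` is a hypothetical helper written only for illustration; it is not part of PyTorch and does not reflect how PyTorch implements gather. It mirrors the shape and bounds rules listed above.

```python
# Hypothetical pure-Python reference for 2-D gather semantics (illustration only,
# not PyTorch's implementation).
def gather_2d(input, dim, index):
    """Return out with out[i][j] = input[index[i][j]][j] (dim=0)
    or out[i][j] = input[i][index[i][j]] (dim=1)."""
    if dim not in (0, 1):
        raise ValueError("dim must be 0 or 1 for a 2-D input")
    rows_in, cols_in = len(input), len(input[0])
    rows_idx, cols_idx = len(index), len(index[0])
    # Along every dimension other than `dim`, index may not be larger than input.
    if dim == 1 and rows_idx > rows_in:
        raise ValueError("index.size(0) must be <= input.size(0)")
    if dim == 0 and cols_idx > cols_in:
        raise ValueError("index.size(1) must be <= input.size(1)")
    size_along_dim = rows_in if dim == 0 else cols_in
    out = []
    for i in range(rows_idx):
        row = []
        for j in range(cols_idx):
            k = index[i][j]
            # Each index value must address a valid position along `dim`.
            if not 0 <= k < size_along_dim:
                raise IndexError(f"index {k} out of bounds for dim {dim} with size {size_along_dim}")
            row.append(input[k][j] if dim == 0 else input[i][k])
        out.append(row)
    return out  # same shape as index


inp = [[1, 2], [3, 4]]
idx = [[0, 0], [1, 0]]
print(gather_2d(inp, 1, idx))  # [[1, 1], [4, 3]]
print(gather_2d(inp, 0, idx))  # [[1, 2], [3, 2]]
```

The `dim=1` result matches the `torch.gather` example that follows; with `dim=0` the same index selects rows within each column instead.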
Example:
```
import torch
input = torch.tensor([[1, 2], [3, 4]])
index = torch.tensor([[0, 0], [1, 0]])
output = torch.gather(input, 1, index)  # output[i][j] = input[i][index[i][j]]
print(output)
```
Output:
```
tensor([[1, 1],
        [4, 3]])
```
In this example, both `input` and `index` have shape (2, 2). Gathering along dimension 1 means that for each row `i`, `output[i][j] = input[i][index[i][j]]`: row 0 picks columns 0 and 0 (values 1 and 1), and row 1 picks columns 1 and 0 (values 4 and 3). The output has the same shape as `index`.
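The sketch below, assuming PyTorch is installed, reuses the tensors from the example to contrast `dim=1` with `dim=0`, and shows that `index` may be smaller than `input` along the non-gathered dimension. The variable names (`out_dim1`, `out_dim0`, `labels`, `picked`) are illustrative only.

```python
import torch

inp = torch.tensor([[1, 2], [3, 4]])
idx = torch.tensor([[0, 0], [1, 0]])  # integer literals default to int64 (torch.long)

# dim=1: out[i][j] = inp[i][idx[i][j]]  (choose columns within each row)
out_dim1 = torch.gather(inp, 1, idx)
# dim=0: out[i][j] = inp[idx[i][j]][j]  (choose rows within each column)
out_dim0 = torch.gather(inp, 0, idx)

# A (2, 1) index picks one element per row; unsqueeze/squeeze convert
# between a 1-D vector of positions and the 2-D index gather expects.
labels = torch.tensor([1, 0])
picked = inp.gather(1, labels.unsqueeze(1)).squeeze(1)

print(out_dim1.tolist())  # [[1, 1], [4, 3]]
print(out_dim0.tolist())  # [[1, 2], [3, 2]]
print(picked.tolist())    # [2, 3]
```

The last pattern is a common way to select one value per sample, for example the score at each sample's label position.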