torch.gather
Posted: 2023-09-03 15:15:15
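torch.gather is a PyTorch function that gathers values from a tensor along a given dimension. Its signature is `torch.gather(input, dim, index)`, and it takes three main arguments:
- input: the tensor from which we want to extract values.
- dim: the dimension (axis) along which to index.
- index: an integer (int64) tensor holding the indices of the values to extract. It must have the same number of dimensions as input, and index.size(d) <= input.size(d) for every dimension d other than dim.
The output tensor has the same shape as the index tensor. For a 2-D input, each output element is out[i][j] = input[index[i][j]][j] when dim=0, and out[i][j] = input[i][index[i][j]] when dim=1.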
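To make the indexing rule concrete, here is a minimal loop-based sketch of what gather computes for 2-D tensors. The helper `gather_2d_reference` is a hypothetical name used only for illustration; it is not part of PyTorch, and it is far slower than the real function. It simply checks that the formula above matches `torch.gather` on the example data.

```python
import torch

def gather_2d_reference(input_tensor, dim, index_tensor):
    """Hypothetical loop-based sketch of torch.gather for 2-D tensors (illustration only)."""
    assert dim in (0, 1), "this sketch only handles 2-D tensors"
    rows, cols = index_tensor.shape
    out = torch.empty((rows, cols), dtype=input_tensor.dtype)
    for i in range(rows):
        for j in range(cols):
            k = index_tensor[i, j].item()
            if dim == 0:
                out[i, j] = input_tensor[k, j]  # index chooses the row, column stays j
            else:
                out[i, j] = input_tensor[i, k]  # row stays i, index chooses the column
    return out

input_tensor = torch.tensor([[1, 2], [3, 4], [5, 6]])
index_tensor = torch.tensor([[0, 1], [1, 0]])
expected = torch.gather(input_tensor, 1, index_tensor)
result = gather_2d_reference(input_tensor, 1, index_tensor)
print(torch.equal(result, expected))  # True
```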
Example:
```python
import torch
# Define input tensor
input_tensor = torch.tensor([[1, 2], [3, 4], [5, 6]])
# Define index tensor
index_tensor = torch.tensor([[0, 1], [1, 0]])
# Gather values from input tensor
output_tensor = torch.gather(input_tensor, 1, index_tensor)
print(output_tensor)
```
Output:
```
tensor([[1, 2],
        [4, 3]])
```
In this example, the input tensor has shape (3, 2) and the index tensor has shape (2, 2). Calling `torch.gather` with `dim=1` (the second dimension) means each index selects a column within its own row: out[i][j] = input[i][index[i][j]]. Row 0 of the index, [0, 1], picks input[0][0] = 1 and input[0][1] = 2; row 1, [1, 0], picks input[1][1] = 4 and input[1][0] = 3. Because the index tensor has only two rows, the third row of the input ([5, 6]) is never read. The output has the same shape as the index tensor, (2, 2).
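For comparison, the sketch below gathers from the same input along `dim=0`, where each index picks a row for its column instead. It also shows one common use of gather: picking a single element per row, such as the score of each sample's label. The `index_rows`, `scores`, and `labels` values are made-up example data, not from the original post.

```python
import torch

input_tensor = torch.tensor([[1, 2], [3, 4], [5, 6]])

# dim=0: out[i][j] = input[index[i][j]][j], so each index chooses a ROW for column j
index_rows = torch.tensor([[0, 2], [1, 0]])
by_rows = torch.gather(input_tensor, 0, index_rows)
print(by_rows)
# tensor([[1, 6],
#         [3, 2]])

# Pick one element per row (made-up data): the score at each sample's label.
scores = torch.tensor([[0.1, 0.7, 0.2],
                       [0.5, 0.3, 0.2]])
labels = torch.tensor([1, 0])
# The index needs the same number of dims as scores, so reshape labels from (2,) to (2, 1)
picked = scores.gather(1, labels.unsqueeze(1))  # output shape (2, 1)
print(picked.squeeze(1))  # tensor([0.7000, 0.5000])
```

The `unsqueeze(1)` step is needed because gather requires the index to have the same number of dimensions as the input; the output then has shape (2, 1) and can be squeezed back to a 1-D tensor.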