tensor.gather
Date: 2024-05-15 09:17:05
`gather` is a PyTorch tensor method (also available as the function `torch.gather`) that collects values from a tensor along a specified dimension. It takes two arguments: `dim`, the dimension along which to gather, and `index`, an integer (`int64`) tensor giving the position of each value to pick along that dimension. The `index` tensor must have the same number of dimensions as the input, and the output has the same shape as `index`.
For example, consider a 2D tensor `x`:
```
import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
```
To gather elements along the second axis (columns), you can use the following code:
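```
# One column index per row of x
indices = torch.tensor([1, 2, 0])
# Reshape indices from (3,) to (3, 1) so it has the same number of dimensions as x
gathered_values = x.gather(1, indices.unsqueeze(1))
```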
Here, `indices` holds one column index per row of `x`. The `unsqueeze(1)` call reshapes `indices` from shape `(3,)` to `(3, 1)`. This is required because `gather` does not broadcast `index` against `x`: the two tensors must have the same number of dimensions. The resulting `gathered_values` tensor has the same shape as the index, `(3, 1)`:
```
tensor([[2],
        [6],
        [7]])
```
This is equivalent to selecting the second element from the first row, the third element from the second row, and the first element from the third row.
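The selection rule can be spelled out in plain Python, which may make it easier to see what `dim` changes. The sketch below is an illustration only, limited to 2D nested lists; `gather_2d` is a hypothetical helper name, not part of PyTorch. It follows the documented rule that for `dim=1` each output element is `x[i][index[i][j]]`, and for `dim=0` it is `x[index[i][j]][j]`.

```python
# Plain-Python sketch of the gather rule for 2D inputs (no PyTorch needed).
# `gather_2d` is a hypothetical helper name, not a PyTorch function.

def gather_2d(x, index, dim):
    """Mimic Tensor.gather on nested lists for dim 0 or 1."""
    out = []
    for i, index_row in enumerate(index):
        out_row = []
        for j, k in enumerate(index_row):
            if dim == 1:
                out_row.append(x[i][k])  # out[i][j] = x[i][index[i][j]]
            elif dim == 0:
                out_row.append(x[k][j])  # out[i][j] = x[index[i][j]][j]
            else:
                raise ValueError("dim must be 0 or 1 for a 2D input")
        out.append(out_row)
    return out


x = [[1, 2, 3],
     [4, 5, 6],
     [7, 8, 9]]

# Same indices as the example above, already shaped as a 3x1 column.
along_columns = gather_2d(x, [[1], [2], [0]], dim=1)  # [[2], [6], [7]]

# With dim=0 the indices pick a row for each column instead.
along_rows = gather_2d(x, [[1, 2, 0]], dim=0)  # [[4, 8, 3]]
```

In both cases the output has the same shape as the index, which is why the column example needs a `(3, 1)` index rather than a flat one.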