torch.searchsorted
Time: 2023-10-14 20:13:29
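torch.searchsorted is a PyTorch function that finds the indices at which elements should be inserted into a sorted tensor so that the order is preserved. Its two main inputs are the sorted tensor (sorted_sequence, which may be 1-D, or N-D with sorting along the innermost dimension) and the values to look up. The output has the same shape as values and holds the insertion index for each value. By default (right=False), the leftmost valid index is returned: the index i satisfying sorted_sequence[i-1] < v <= sorted_sequence[i]. So if a value is already present in the sorted tensor, the result is the index of its first occurrence. With right=True, the rightmost valid index is returned instead (sorted_sequence[i-1] <= v < sorted_sequence[i]). If a value is greater than all elements in the sorted tensor, the result is the length of the sorted tensor, one past the index of the last element; if it is smaller than all elements, the result is 0.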
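To make the left and right rules concrete, here is a minimal pure-Python sketch of the binary search they describe, assuming a 1-D list sorted in ascending order. It is an illustration, not PyTorch's actual implementation, and the helper names searchsorted_left and searchsorted_right are hypothetical. They compute the same insertion points as the standard library's bisect.bisect_left and bisect.bisect_right.

```python
def searchsorted_left(sorted_seq, value):
    """Hypothetical helper: leftmost i with sorted_seq[i-1] < value <= sorted_seq[i]
    (mirrors the default right=False rule)."""
    lo, hi = 0, len(sorted_seq)
    while lo < hi:
        mid = (lo + hi) // 2
        if sorted_seq[mid] < value:
            lo = mid + 1  # insertion point lies strictly right of mid
        else:
            hi = mid      # sorted_seq[mid] >= value, so mid is still a candidate
    return lo             # equals len(sorted_seq) if value exceeds every element


def searchsorted_right(sorted_seq, value):
    """Hypothetical helper: rightmost i with sorted_seq[i-1] <= value < sorted_seq[i]
    (mirrors the right=True rule)."""
    lo, hi = 0, len(sorted_seq)
    while lo < hi:
        mid = (lo + hi) // 2
        if sorted_seq[mid] <= value:
            lo = mid + 1  # equal elements are skipped over
        else:
            hi = mid
    return lo


seq = [1, 2, 3, 5, 7, 8]
print([searchsorted_left(seq, v) for v in [2, 4, 6, 9]])   # [1, 3, 4, 6]
print([searchsorted_right(seq, v) for v in [2, 4, 6, 9]])  # [2, 3, 4, 6]
```

The only difference between the two helpers is whether an element equal to the value moves the search to the right (< versus <=), which is exactly what decides where a value that is already present gets inserted.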
Here is an example usage:
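```python
import torch

# A 1-D tensor sorted in ascending order
sorted_tensor = torch.tensor([1, 2, 3, 5, 7, 8])
# Values whose insertion points we want
values = torch.tensor([2, 4, 6, 9])
# Default (right=False): leftmost insertion index for each value
indices = torch.searchsorted(sorted_tensor, values)
print(indices)  # Output: tensor([1, 3, 4, 6])
```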
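In this example, the insertion indices for the values [2, 4, 6, 9] in the sorted tensor [1, 2, 3, 5, 7, 8] are [1, 3, 4, 6]. The value 2 is already present at index 1, so its leftmost insertion point is 1; 4 falls between 3 (index 2) and 5 (index 3), giving 3; 6 falls between 5 and 7, giving 4; and 9 is larger than every element, giving 6, the length of the sorted tensor.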
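The difference between the default and right=True is easiest to see with a sequence containing a repeated value. The sketch below uses made-up illustrative tensors and only the right keyword argument of torch.searchsorted. It also shows a 2-D sorted_sequence, where each row is searched independently against the corresponding row of values.

```python
import torch

# Illustrative sorted sequence in which 3 appears twice
sorted_seq = torch.tensor([1, 3, 3, 5, 7])
values = torch.tensor([3, 0, 8])

# Default (right=False): index of the first element >= value (len if none)
left = torch.searchsorted(sorted_seq, values)
# right=True: index of the first element > value (len if none)
right = torch.searchsorted(sorted_seq, values, right=True)
print(left.tolist())   # [1, 0, 5]
print(right.tolist())  # [3, 0, 5]

# 2-D case: each row is sorted along the last dimension and searched
# independently; the leading dimension of queries (2 rows) matches rows.
rows = torch.tensor([[1, 3, 5, 7, 9],
                     [2, 4, 6, 8, 10]])
queries = torch.tensor([[3, 6, 9],
                        [3, 6, 9]])
batched = torch.searchsorted(rows, queries)
print(batched.tolist())  # [[1, 3, 4], [1, 2, 4]]
```

The returned indices are int64 by default; passing out_int32=True returns int32 indices instead.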