我想直接得到这些向量,不只是索引
时间: 2023-08-02 21:39:38 浏览: 37
如果您想直接得到A中每个向量在B中距离最远的向量,可以使用`torch.index_select()`函数结合`torch.argmax()`函数来实现。
具体而言,您可以使用以下代码:
```python
import torch
A = torch.randn(32, 256)
B = torch.randn(32, 256)
distances = torch.cdist(A, B, p=2.0)
farthest_indices = torch.argmax(distances, dim=1)
farthest_vectors = torch.index_select(B, dim=0, index=farthest_indices)
```
其中,`distances`是由`torch.cdist()`函数计算得到的形状为`(32, 32)`的张量,表示A中每个向量和B中每个向量之间的欧氏距离。`farthest_indices`是由`torch.argmax()`函数计算得到的形状为`(32,)`的张量,表示A中每个向量在B中距离最远的向量的索引。`farthest_vectors`是由`torch.index_select()`函数计算得到的形状为`(32, 256)`的张量,表示A中每个向量在B中距离最远的向量。
需要注意的是,`torch.index_select()`函数的第一个参数是要进行索引的张量,这里是B;第二个参数是要进行索引的维度,这里是第0维,即按行索引;第三个参数是要进行索引的索引值,这里是`farthest_indices`。
同样,如果您只需要计算A中每个向量在B中距离最近的向量,可以将`torch.argmax()`替换为`torch.argmin()`即可。