torch.cdist
Time: 2023-10-31 21:11:10
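torch.cdist() is a PyTorch function that computes the p-norm distance between every pair of row vectors drawn from two collections. Given x1 of shape (P, M) and x2 of shape (R, M), optionally with a leading batch dimension, it returns a tensor of shape (P, R) whose entry [i, j] is the distance between x1[i] and x2[j].
The metric is selected with p: p=1 gives the Manhattan distance, p=2 the Euclidean distance, and p=float('inf') the Chebyshev distance. Cosine distance is not a p-norm distance, so torch.cdist does not support it; it has to be computed separately.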
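The sketch below compares cdist for several values of p against a brute-force broadcasting version. It also shows one common way to get cosine distance, since cdist has no option for it. The helpers `pairwise_pnorm_reference` and `pairwise_cosine_distance` are hypothetical names used only for illustration. The code assumes a PyTorch version recent enough to provide `torch.linalg.vector_norm`.

```python
import torch
import torch.nn.functional as F


def pairwise_pnorm_reference(x1: torch.Tensor, x2: torch.Tensor, p: float) -> torch.Tensor:
    # Hypothetical helper: brute-force p-norm distances via broadcasting.
    # x1: (P, M), x2: (R, M) -> result: (P, R)
    diff = x1[:, None, :] - x2[None, :, :]  # (P, R, M) pairwise differences
    return torch.linalg.vector_norm(diff, ord=p, dim=-1)


def pairwise_cosine_distance(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: cosine distance is NOT a cdist option, so compute
    # 1 - cosine similarity from L2-normalized rows and a matrix product.
    return 1.0 - F.normalize(x1, dim=-1) @ F.normalize(x2, dim=-1).T


torch.manual_seed(0)
x1 = torch.randn(5, 3)
x2 = torch.randn(4, 3)

# p=1: Manhattan, p=2: Euclidean, p=inf: Chebyshev
for p in (1.0, 2.0, float("inf")):
    same = torch.allclose(torch.cdist(x1, x2, p=p), pairwise_pnorm_reference(x1, x2, p), atol=1e-5)
    print(f"p={p}: matches reference -> {same}")

print(pairwise_cosine_distance(x1, x2).shape)  # torch.Size([5, 4])
```

The broadcasting version materializes a (P, R, M) tensor, so it is only practical for small inputs. It is meant as a readable check, not a replacement for cdist.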
Syntax:
torch.cdist(x1, x2, p=2, compute_mode='use_mm_for_euclid_dist_if_necessary')
where,
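- x1, x2: floating-point input tensors of shape B×P×M and B×R×M (or P×M and R×M without a batch dimension); the result has shape B×P×R (or P×R)
- p: the p value of the p-norm distance, with p ≥ 0 (default 2.0, i.e., Euclidean distance)
- compute_mode: how Euclidean distances (p=2) are computed. The default 'use_mm_for_euclid_dist_if_necessary' switches to a faster matrix-multiplication approach when P > 25 or R > 25; 'use_mm_for_euclid_dist' always uses it, and 'donot_use_mm_for_euclid_dist' never does. The option only matters for p=2.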
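The sketch below uses illustrative shapes and random data, not values from the original. It checks the batched output shape and compares the compute_mode settings. It also rebuilds the matrix-multiplication route from the identity ||a − b||² = ||a||² + ||b||² − 2a·b. That identity is the general idea behind the mm mode; this is not claimed to be PyTorch's exact kernel. float64 is used so the two routes agree closely.

```python
import torch

torch.manual_seed(0)
B, P, R, M = 2, 30, 40, 8  # illustrative sizes; P, R > 25 lets the default mode pick the mm path
x1 = torch.randn(B, P, M, dtype=torch.float64)
x2 = torch.randn(B, R, M, dtype=torch.float64)

d_default = torch.cdist(x1, x2)  # p=2, compute_mode='use_mm_for_euclid_dist_if_necessary'
d_mm = torch.cdist(x1, x2, compute_mode="use_mm_for_euclid_dist")
d_direct = torch.cdist(x1, x2, compute_mode="donot_use_mm_for_euclid_dist")
print(d_default.shape)  # torch.Size([2, 30, 40])

# Matrix-multiplication route: ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b
sq1 = (x1 ** 2).sum(dim=-1, keepdim=True)                    # (B, P, 1)
sq2 = (x2 ** 2).sum(dim=-1, keepdim=True).transpose(-1, -2)  # (B, 1, R)
sq_dist = sq1 + sq2 - 2 * x1 @ x2.transpose(-1, -2)          # (B, P, R)
d_expanded = sq_dist.clamp(min=0).sqrt()  # clamp guards against tiny negative rounding errors

print(torch.allclose(d_mm, d_direct, atol=1e-6))
print(torch.allclose(d_expanded, d_direct, atol=1e-6))
```

The expansion replaces an explicit per-pair difference with one batched matrix product, which is why it is faster for large P and R. The subtraction can lose precision when two points are very close, and that trade-off is the reason a direct mode is offered.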
Example:
```
import torch
# cdist only accepts floating-point tensors, so use float literals
x1 = torch.tensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
x2 = torch.tensor([[2., 2.], [2., 3.], [3., 2.], [3., 3.]])
distances = torch.cdist(x1, x2, p=2)
print(distances)
```
Output:
```
tensor([[2.8284, 3.6056, 3.6056, 4.2426],
        [2.2361, 2.8284, 3.1623, 3.6056],
        [2.2361, 3.1623, 2.8284, 3.6056],
        [1.4142, 2.2361, 2.2361, 2.8284]])
```
In this example, torch.cdist computes the Euclidean distance between every vector in x1 and every vector in x2. Entry [i, j] of the 4×4 result is the distance between x1[i] and x2[j]. For example, distances[0, 0] = √((0−2)² + (0−2)²) = √8 ≈ 2.8284.
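To check the printed matrix by hand, the sketch below recomputes one entry directly. It then compares the whole result with the square roots of squared distances worked out pair by pair (dx² + dy²). The variable `expected_sq` is an illustrative name introduced here.

```python
import math
import torch

x1 = torch.tensor([[0., 0.], [0., 1.], [1., 0.], [1., 1.]])
x2 = torch.tensor([[2., 2.], [2., 3.], [3., 2.], [3., 3.]])
distances = torch.cdist(x1, x2, p=2)

# Entry [i, j] is the Euclidean distance between row i of x1 and row j of x2.
i, j = 0, 3
manual = math.sqrt(((x1[i] - x2[j]) ** 2).sum().item())
print(round(manual, 4))  # 4.2426, i.e. sqrt(18)

# Squared distances worked out by hand: dx^2 + dy^2 for every (i, j) pair
expected_sq = torch.tensor([[8., 13., 13., 18.],
                            [5., 8., 10., 13.],
                            [5., 10., 8., 13.],
                            [2., 5., 5., 8.]])
print(torch.allclose(distances, expected_sq.sqrt(), atol=1e-4))
```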