使用pytorch替代sklearn中的euclidean_distances
时间: 2023-09-13 14:13:50 浏览: 42
可以使用PyTorch中的torch.cdist函数来替代sklearn中的euclidean_distances函数。torch.cdist函数可以计算两个张量之间所有点之间的距离,并返回一个新的张量。
以下是一个示例代码:
```python
import torch
# 生成两个张量
x = torch.randn(10, 3)
y = torch.randn(5, 3)
# 计算x和y之间的欧氏距离
distances = torch.cdist(x, y)
print(distances)
```
输出:
```
tensor([[1.9769, 1.2909, 1.9011, 1.5367, 2.6600],
[1.2516, 1.7855, 2.2966, 1.5408, 1.8056],
[1.9801, 1.2029, 1.2994, 0.6682, 2.7724],
[1.6420, 2.3039, 2.0901, 1.7433, 1.0737],
[0.6109, 1.7528, 1.3279, 1.6867, 1.1562],
[2.0429, 1.3588, 1.7649, 1.3937, 2.3472],
[2.3322, 2.9492, 2.9717, 2.7183, 1.7593],
[1.2340, 0.8943, 1.0644, 0.3252, 2.3668],
[1.5620, 1.6265, 1.7199, 1.2812, 1.6146],
[1.8884, 1.2714, 1.9464, 1.5634, 2.6892]])
```
以上代码生成了两个张量x和y,分别表示10个点和5个点在三维空间中的坐标。然后使用torch.cdist函数计算了x和y之间的欧氏距离,并将结果保存在distances张量中。最后输出distances张量,即可得到所有点之间的距离矩阵。