计算两个二维分布的Wasserstein距离实例
时间: 2023-06-13 14:02:06 浏览: 401
假设我们有两个二维分布 $P$ 和 $Q$,我们可以使用 PyTorch 中的 POT(Python Optimal Transport)库来计算它们之间的Wasserstein距离。
以下是一个示例代码:
```
import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import multivariate_normal
import ot
#生成两个二维高斯分布
mean1 = [0, 0]
cov1 = [[1, 0], [0, 1]]
P = multivariate_normal(mean1, cov1)
mean2 = [3, 3]
cov2 = [[1, 0], [0, 1]]
Q = multivariate_normal(mean2, cov2)
#生成样本
N = 1000
X = P.rvs(size=N)
Y = Q.rvs(size=N)
#计算距离矩阵
M = ot.dist(X, Y)
#计算Wasserstein距离
Wd = ot.emd2([], [], M)
print("Wasserstein距离为:", Wd)
#绘制分布和样本
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].scatter(X[:, 0], X[:, 1], c='b', alpha=0.5)
axs[0].scatter(Y[:, 0], Y[:, 1], c='r', alpha=0.5)
axs[0].set_title("分布")
axs[1].scatter(X[:, 0], X[:, 1], c='b', alpha=0.5)
axs[1].scatter(Y[:, 0], Y[:, 1], c='r', alpha=0.5)
axs[1].set_title("样本")
plt.show()
```
以上代码将生成两个二维高斯分布(蓝色和红色),并从每个分布中采样1000个样本。然后,我们使用OT库计算距离矩阵和Wasserstein距离,并将其打印出来。最后,我们绘制了分布和样本的散点图,以便更好地可视化它们之间的差异。
请注意,这里我们使用了 OT 库的`emd2`函数来计算Wasserstein距离。对于更大的数据集,可能需要使用更高级的算法来加速计算。
阅读全文