plt_divergence
时间: 2023-08-30 12:11:11 浏览: 96
plt_divergence 是一个函数,用于绘制两个概率分布之间的差异或散度。它通常用于比较两个模型的性能,或者比较一个模型在不同数据集上的性能。常见的散度指标包括 KL 散度、交叉熵、JS 散度等。
函数的输入参数通常包括两个概率分布的概率密度函数或概率质量函数,以及一个可选参数,用于指定使用哪种散度指标。函数的输出通常是一个绘制好的散点图或柱状图,其中横轴表示样本空间,纵轴表示散度值。
示例代码:
```
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm
def kl(p, q):
return np.sum(np.where(p != 0, p * np.log(p / q), 0))
def plt_divergence(p, q, divergence='kl'):
if divergence == 'kl':
div_func = kl
elif divergence == 'cross_entropy':
div_func = lambda p, q: -np.sum(np.where(p != 0, p * np.log(q), 0))
elif divergence == 'js':
def js(p, q):
m = (p + q) / 2
return (kl(p, m) + kl(q, m)) / 2
div_func = js
else:
raise ValueError('Unknown divergence metric')
x = np.linspace(-5, 5, 100)
p_pdf = norm.pdf(x, loc=p[0], scale=p[1])
q_pdf = norm.pdf(x, loc=q[0], scale=q[1])
div = div_func(p_pdf, q_pdf)
plt.plot(x, p_pdf, label='p')
plt.plot(x, q_pdf, label='q')
plt.legend()
plt.title(f'{divergence} divergence: {div:.3f}')
plt.show()
# Example usage
p = (0, 1)
q = (1, 2)
plt_divergence(p, q, divergence='kl')
```
阅读全文