Numpy 实现余弦相似度embedding 的NTxent
时间: 2024-02-09 19:09:18 浏览: 76
下面是使用 Numpy 实现余弦相似度的 NTxent 损失的示例代码。
首先,我们需要定义一个嵌入矩阵 $E$,其中每一行代表一个单词的嵌入向量。假设我们有 $N$ 个单词,每个单词的嵌入向量为 $d$ 维,那么 $E$ 的形状将为 $N \times d$。我们可以使用 NumPy 的随机函数生成一个随机的嵌入矩阵:
```python
import numpy as np
N = 10000
d = 300
E = np.random.randn(N, d)
```
接下来,我们需要选择一些中心单词 $c_i$,以及它们对应的正样本单词 $p_i$ 和负样本单词 $n_i$。我们可以使用 NumPy 的随机函数从嵌入矩阵 $E$ 中随机选择这些单词:
```python
batch_size = 32
c_idx = np.random.randint(N, size=batch_size)
e_c = E[c_idx]
p_idx = np.random.randint(N, size=batch_size)
e_p = E[p_idx]
n_idx = np.random.randint(N, size=(batch_size, 5))
for i in range(batch_size):
while p_idx[i] in n_idx[i]:
n_idx[i] = np.random.randint(N, size=5)
e_n = E[n_idx]
```
接下来,我们可以计算每个中心单词 $c_i$ 和对应的正样本单词 $p_i$ 的余弦相似度:
```python
cos_sim = np.sum(e_c * e_p, axis=1) / (np.linalg.norm(e_c, axis=1) * np.linalg.norm(e_p, axis=1))
```
然后,我们可以计算每个中心单词 $c_i$ 和对应的负样本单词 $n_{i,j}$ 的余弦相似度:
```python
cos_sim_neg = np.sum(e_c[:, np.newaxis, :] * e_n, axis=2) / (np.linalg.norm(e_c, axis=1)[:, np.newaxis] * np.linalg.norm(e_n, axis=2))
```
接下来,我们需要将余弦相似度转换为概率分布,并计算 NTxent 损失。假设我们使用 softmax 函数将余弦相似度转换为概率分布,我们可以使用以下代码计算 NTxent 损失:
```python
temperature = 0.1
logit = cos_sim / temperature
logit_neg = cos_sim_neg / temperature
logit_all = np.concatenate([np.array([logit]), logit_neg], axis=0)
logit_all = np.exp(logit_all)
probs = logit_all / np.sum(logit_all, axis=0)
log_prob = np.log(probs[0] / np.sum(probs[1:], axis=0))
loss = -np.mean(log_prob)
```
其中,`temperature` 是一个超参数。我们可以使用类似的方式计算每个中心单词 $c_i$ 和对应的负样本单词 $n_{i,j}$ 的 NTxent 损失。完整的代码示例如下:
```python
import numpy as np
N = 10000
d = 300
temperature = 0.1
# 生成随机嵌入矩阵
E = np.random.randn(N, d)
# 选择中心单词和正样本单词
batch_size = 32
c_idx = np.random.randint(N, size=batch_size)
e_c = E[c_idx]
p_idx = np.random.randint(N, size=batch_size)
e_p = E[p_idx]
# 选择负样本单词
n_idx = np.random.randint(N, size=(batch_size, 5))
for i in range(batch_size):
while p_idx[i] in n_idx[i]:
n_idx[i] = np.random.randint(N, size=5)
e_n = E[n_idx]
# 计算余弦相似度
cos_sim = np.sum(e_c * e_p, axis=1) / (np.linalg.norm(e_c, axis=1) * np.linalg.norm(e_p, axis=1))
cos_sim_neg = np.sum(e_c[:, np.newaxis, :] * e_n, axis=2) / (np.linalg.norm(e_c, axis=1)[:, np.newaxis] * np.linalg.norm(e_n, axis=2))
# 计算 NTxent 损失
logit = cos_sim / temperature
logit_neg = cos_sim_neg / temperature
logit_all = np.concatenate([np.array([logit]), logit_neg], axis=0)
logit_all = np.exp(logit_all)
probs = logit_all / np.sum(logit_all, axis=0)
log_prob = np.log(probs[0] / np.sum(probs[1:], axis=0))
loss = -np.mean(log_prob)
```
阅读全文