simsiam损失函数
时间: 2025-01-02 21:32:00 浏览: 13
### Simsiam 模型中的损失函数
Simsiam模型采用了负对数余弦相似度作为损失函数[^2]。具体来说,该损失函数用于衡量两个特征表示\(q_1\)和\(z_2\)以及\(q_2\)和\(z_1\)间的差异程度。
#### 负对数余弦相似度定义
余弦相似度测量的是两个非零向量之间夹角的余弦值,在这里被用来评估由编码器产生的不同视角下的样本特征之间的关系。对于一对输入图像的不同增强视图,分别记作\(x_1\)和\(x_2\),经过编码器处理后获得对应的预测向量\(p_1, p_2\)和投影向量\(z_1,z_2\)。损失函数可以表达如下:
\[L(p_i, z_j)=-\log{\left(\frac{\exp{(sim(p_i, z_j)/τ)}}{\sum_k \exp{(sim(p_i, z_k)/τ)}}\right)}\]
其中,
- \(i,j∈\{1,2\}\),且\(i≠j\);
- \(sim(a,b)\)代表a与b间标准化后的点积运算;
- τ是一个温度参数,通常设为0.5;
值得注意的是,尽管上述公式看起来像是对比损失的形式,但实际上由于没有引入负样本参与计算,因此并不属于真正的对比损失范畴。相反,这是通过对称形式来实现自我监督学习的一种方式[^5]。
```python
import torch.nn.functional as F
def cosine_similarity_loss(p, z):
# Normalize the vectors to unit length.
p_norm = F.normalize(p, dim=1)
z_norm = F.normalize(z, dim=1)
# Compute similarity between normalized vectors.
sim = (p_norm * z_norm).sum(dim=1)
# Return negative log of softmax over similarities.
return -(F.log_softmax(sim.unsqueeze(1), dim=1)).mean()
```
阅读全文