SimCLR PyTorch code
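SimCLR (Simple Framework for Contrastive Learning of Visual Representations) is an unsupervised (self-supervised) method for learning visual representations. It trains a deep neural network by contrasting pairs of samples, typically two differently augmented views of the same image. The basic steps for implementing SimCLR in PyTorch are:
1. **Data preprocessing**: Load the image data and apply random augmentations such as flipping, cropping, and color adjustment. Two independently augmented views of the same image form a positive pair; views of the other images in the batch act as negatives.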
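```python
import torch
from torchvision import transforms

# SimCLR-style augmentations: random crop, flip, color jitter, grayscale
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),  # must precede Normalize, which expects a tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def get_augmented_pairs(x):
    # Two independent random augmentations of the same image form a positive pair
    return transform(x), transform(x)
```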
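Training needs batches of pairs, so one option is a dataset wrapper that returns two independently augmented views of each image. The sketch below assumes a base dataset that yields `(image, label)` tuples, as torchvision datasets do. It accepts any augmentation callable, such as `transform` above. The class name `TwoViewDataset` is hypothetical, and the demo uses a noise-adding stand-in so it runs without real images.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class TwoViewDataset(Dataset):
    """Hypothetical wrapper: yields two augmented views per image and discards labels."""

    def __init__(self, base_dataset, transform):
        self.base = base_dataset
        self.transform = transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        image, _ = self.base[idx]  # assumes the base dataset yields (image, label)
        # Each call to transform samples fresh random augmentation parameters
        return self.transform(image), self.transform(image)


def add_noise(x):
    # Stand-in for the real augmentation pipeline in this toy demo
    return x + 0.1 * torch.randn_like(x)


base = [(torch.rand(3, 32, 32), 0) for _ in range(8)]  # random "images" with dummy labels
dataset = TwoViewDataset(base, add_noise)
loader = DataLoader(dataset, batch_size=4, shuffle=True)
x_i, x_j = next(iter(loader))
print(x_i.shape, x_j.shape)  # torch.Size([4, 3, 32, 32]) torch.Size([4, 3, 32, 32])
```

The default collate function batches each element of the returned tuple separately, so every batch arrives as two aligned tensors. Row k of `x_i` and row k of `x_j` come from the same image.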
2. **Build the model**: Use PyTorch to build the backbone encoder, for example a CNN such as ResNet, or a Vision Transformer (ViT).
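```python
import torch.nn as nn
import torchvision.models as models

# Randomly initialized ResNet-50 backbone (older torchvision versions use pretrained=False)
model = models.resnet50(weights=None)
# Remove the classification layer so the model outputs 2048-d features
num_features = model.fc.in_features
model.fc = nn.Identity()
```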
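3. **Build the SimCLR architecture**: Add a projection head (projector). This is a small MLP that maps the backbone features into a lower-dimensional space where the contrastive loss is computed. The ResNet backbone already ends in global average pooling, so no extra pooling layer is needed. After pre-training, the projection head is typically discarded and the backbone features are used for downstream tasks.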
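```python
class ProjectionHead(nn.Module):
    def __init__(self, num_features, projection_dim):
        super().__init__()  # required before registering submodules
        # MLP with one hidden layer and a ReLU nonlinearity
        self.projector = nn.Sequential(
            nn.Linear(num_features, projection_dim),
            nn.ReLU(inplace=True),
            nn.Linear(projection_dim, projection_dim),
        )

    def forward(self, x):
        return self.projector(x)

projection_head = ProjectionHead(num_features, projection_dim=128)
```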
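Combining steps 2 and 3 gives a SimCLR model: the backbone followed by the projection head. The loss is computed on the projection output, and the backbone features are what you keep afterwards. Below is a minimal sketch using a hypothetical `SimCLRModel` class with a hidden layer as wide as the backbone features. The demo uses a tiny stand-in encoder so it runs without torchvision. The ResNet-50 with `fc = nn.Identity()` from above (2048-d features) would be passed in the same way.

```python
import torch
import torch.nn as nn


class SimCLRModel(nn.Module):
    """Hypothetical wrapper: encoder f(.) followed by an MLP projection head g(.)."""

    def __init__(self, encoder, num_features, projection_dim=128):
        super().__init__()
        self.encoder = encoder
        # Two-layer MLP head with a ReLU between the linear layers
        self.projector = nn.Sequential(
            nn.Linear(num_features, num_features),
            nn.ReLU(inplace=True),
            nn.Linear(num_features, projection_dim),
        )

    def forward(self, x):
        h = self.encoder(x)    # representation kept for downstream tasks
        z = self.projector(h)  # projection used only by the contrastive loss
        return h, z


# Tiny stand-in encoder that outputs 64-d features
encoder = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.AdaptiveAvgPool2d(1), nn.Flatten())
simclr = SimCLRModel(encoder, num_features=64)
h, z = simclr(torch.randn(4, 3, 32, 32))
print(h.shape, z.shape)  # torch.Size([4, 64]) torch.Size([4, 128])
```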
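4. **Optimizer setup**: Use AdamW or another optimizer, together with a learning-rate schedule (for example, linear warm-up followed by cosine decay).
5. **Training loop**: For each batch, encode both augmented views and compute a contrastive loss over the embeddings. The other view of the same image is the positive; all other views in the batch are negatives. Then update the model parameters.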
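```python
params = list(model.parameters()) + list(projection_head.parameters())  # backbone + head
optimizer = torch.optim.AdamW(params, lr=0.005)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
for epoch in range(num_epochs):  # num_epochs and loader are assumed to be defined elsewhere
    for x_i, x_j in loader:  # loader yields the two augmented views of each image
        z_i = projection_head(model(x_i))
        z_j = projection_head(model(x_j))
        loss = simclr_loss_fn(z_i, z_j, temp=0.5)  # NT-Xent loss (defined below); temp is illustrative
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()  # one cosine-annealing step per epoch
```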
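Step 4 mentions warm-up, which `CosineAnnealingLR` alone does not provide. One way to combine a linear warm-up with cosine decay is a `LambdaLR` multiplier that is stepped once per optimizer step. The helper name `warmup_cosine_schedule` and the step counts below are illustrative assumptions, not values taken from SimCLR.

```python
import math
import torch


def warmup_cosine_schedule(optimizer, warmup_steps, total_steps):
    """Hypothetical helper: linear warm-up to the base LR, then cosine decay toward 0."""

    def lr_lambda(step):
        if step < warmup_steps:
            return (step + 1) / warmup_steps  # ramp linearly up to 1.0
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


# Record the learning rate over 20 steps with 5 warm-up steps
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-3)
scheduler = warmup_cosine_schedule(optimizer, warmup_steps=5, total_steps=20)
lrs = []
for _ in range(20):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()  # in real training this follows loss.backward()
    scheduler.step()
print([round(lr, 6) for lr in lrs[:6]])  # [0.0002, 0.0004, 0.0006, 0.0008, 0.001, 0.001]
```

Stepping per batch rather than per epoch gives a smooth warm-up even when there are only a few epochs.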
### SimCLR Loss Function Definition and Implementation
In self-supervised learning, particularly within contrastive learning frameworks like SimCLR, the objective is to learn representations that capture semantic information about data points without explicit labels. The core idea behind SimCLR involves maximizing agreement between differently augmented views of the same input while minimizing similarity with other inputs.
The SimCLR loss pulls together positive pairs (two augmentations of the same image) and pushes apart negative pairs (augmentations of different images). This encourages features that remain robust across the transformations applied during training[^2].
Mathematically, consider a mini-batch of N images, each augmented twice, giving 2N views. For an anchor view \( x_i \), the other view of the same image, \( x_j \), forms the positive pair, and the remaining \( 2(N-1) \) views serve as negatives. Pairwise similarities are computed with cosine similarity:
\[ s_{ij} = \frac{z_i^\top z_j}{\|z_i\|\|z_j\|}, \]
where \( z_i \) is the output of the projection head applied to the backbone encoder's features for view \( i \).
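In practice, all pairwise similarities \( s_{ij} \) for the 2N views are computed at once: L2-normalize each projection, then take a matrix product. The sketch below uses arbitrary toy sizes and checks the result against `torch.nn.functional.cosine_similarity` applied with broadcasting.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, d = 4, 16
z = torch.randn(2 * N, d)  # projection outputs for the 2N augmented views

# Normalizing rows to unit length turns dot products into cosine similarities
z_unit = F.normalize(z, dim=1)
s = z_unit @ z_unit.t()  # s[i, j] = cosine similarity of views i and j, shape (2N, 2N)

# Reference: broadcast every (i, j) pair through cosine_similarity
s_ref = F.cosine_similarity(z.unsqueeze(1), z.unsqueeze(0), dim=-1)
print(torch.allclose(s, s_ref, atol=1e-6))  # True
print(torch.allclose(s.diagonal(), torch.ones(2 * N), atol=1e-6))  # True: each view matches itself
```

The diagonal is always 1, which is why the loss must exclude \( k = i \) from its denominator.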
SimCLR uses the NT-Xent (normalized temperature-scaled cross-entropy) loss, an InfoNCE-style contrastive loss, defined over these similarity scores:
\[ L(i,j) = -\log \frac{\exp(s_{ij}/\tau)}{\sum_{k=1,\,k\neq i}^{2N} \exp(s_{ik}/\tau)}, \]
where \( \tau \) is a temperature parameter. The batch loss averages \( L(i,j) \) over all 2N anchors, so each positive pair contributes both \( L(i,j) \) and \( L(j,i) \). The temperature controls how sharply the softmax concentrates on the true match relative to the distractors. Lower values make decisions more confident and put more weight on hard negatives, but they also make training more sensitive to noise and outliers in the batch. Because every view is compared with every other view in the mini-batch, the full 2N × 2N similarity matrix is computed in a single matrix product. This parallelizes well on GPUs/TPUs and keeps training on large-scale datasets practical [^4].
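To make the indices and the role of \( \tau \) concrete, the sketch below evaluates the per-pair loss \( L(i,j) \) literally with Python loops. It assumes the common convention that rows `0..N-1` hold the first views and row `k + N` is the positive for row `k`. The function name `nt_xent_reference` is hypothetical, and the loops are far too slow for training. The function is still useful as a reference to test a vectorized loss against. The toy batch has a closed-form answer, so the loop can be checked by hand.

```python
import math
import torch
import torch.nn.functional as F


def nt_xent_reference(z, tau):
    """Evaluate L(i, j) literally; z has shape (2N, d) and view k pairs with view k + N."""
    two_n = z.shape[0]
    n = two_n // 2
    z = F.normalize(z, dim=1)
    s = z @ z.t()  # cosine similarities s_ik
    total = 0.0
    for i in range(two_n):
        j = (i + n) % two_n  # index of the positive view for anchor i
        numerator = math.exp(s[i, j].item() / tau)
        # The denominator sums over every k except the anchor itself (k != i)
        denominator = sum(math.exp(s[i, k].item() / tau) for k in range(two_n) if k != i)
        total += -math.log(numerator / denominator)
    return total / two_n  # average over all 2N anchors


# Worked example: two images whose two views are the identical unit vectors e1 and e2
e1, e2 = torch.tensor([1.0, 0.0]), torch.tensor([0.0, 1.0])
z = torch.stack([e1, e2, e1, e2])  # rows 0-1: first views, rows 2-3: second views
# Every anchor has s = 1 with its positive and s = 0 with its two negatives,
# so L = -log(e^(1/tau) / (e^(1/tau) + 2)) = log(1 + 2 e^(-1/tau))
for tau in (1.0, 0.5, 0.1):
    closed_form = math.log(1 + 2 * math.exp(-1 / tau))
    print(f"tau={tau}: loop={nt_xent_reference(z, tau):.4f}, closed form={closed_form:.4f}")
```

As \( \tau \) shrinks, the loss for this well-separated batch drops toward 0 because the softmax puts nearly all its mass on the positive. When negatives lie close to the anchor, the same sharpening tends to concentrate the penalty on those hard negatives.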
The following PyTorch code implements the SimCLR loss:
```python
import torch
from torch.nn.functional import normalize, cross_entropy

def simclr_loss_fn(xis, xjs, temp):
    """NT-Xent loss; row k of xis and row k of xjs are the two views of image k."""
    batch_size = xis.shape[0]
    # Stack both views into 2N vectors and normalize them onto the unit sphere,
    # so that dot products are cosine similarities.
    z = normalize(torch.cat([xis, xjs], dim=0), dim=-1)
    # Pairwise similarity logits s_ik / temp, shape (2N, 2N)
    logits = torch.mm(z, z.t()) / temp
    # Mask out self-comparisons (k == i) so they never enter the denominator
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float("-inf"))
    # The positive for view k is view k + N, and vice versa
    idx = torch.arange(2 * batch_size, device=z.device)
    targets = (idx + batch_size) % (2 * batch_size)
    # Row-wise cross-entropy equals -log(exp(s_ij/temp) / sum_{k != i} exp(s_ik/temp)),
    # averaged over all 2N anchors
    loss = cross_entropy(logits, targets)
    return loss
```