
Distance Metric Learning Classics, Part 2: Contrastive Loss

The Contrastive Loss Formula

L=\frac{1}{2 N} \sum_{n=1}^{N} y_{n} d_{n}^{2}+\left(1-y_{n}\right) \max \left(\operatorname{margin}-d_{n}, 0\right)^{2}

Here d_n = \|a_n - b_n\|_2 is the Euclidean distance between the feature vectors of the two samples, and y_n labels whether the two samples match: y_n = 1 means the pair is similar (a match), y_n = 0 means it is not. margin is a preset threshold; its role is analogous to the margin in the previous post, Distance Metric Learning Classics, Part 1: Triplet Loss.
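As a quick sanity check of the formula, the loss can be computed by hand in plain Python for a couple of toy pairs (the feature values and margin below are made up for illustration):

```python
import math

def contrastive_loss(pairs, margin=1.0):
    """Contrastive loss over a list of (a, b, y) pairs,
    where y = 1 means a and b match and y = 0 means they do not."""
    total = 0.0
    for a, b, y in pairs:
        # Euclidean distance between the two feature vectors.
        d = math.sqrt(sum((ai - bi) ** 2 for ai, bi in zip(a, b)))
        total += y * d ** 2 + (1 - y) * max(margin - d, 0) ** 2
    return total / (2 * len(pairs))  # the 1/(2N) factor from the formula

pairs = [
    ([0.0, 0.0], [0.3, 0.4], 1),  # matching pair, d = 0.5, contributes d^2 = 0.25
    ([0.0, 0.0], [0.3, 0.4], 0),  # non-matching pair inside the margin: (1 - 0.5)^2 = 0.25
]
print(contrastive_loss(pairs, margin=1.0))  # ≈ 0.125
```

A non-matching pair that is already farther apart than the margin contributes nothing, which is exactly the "repulsion only when squeezed too close" behaviour of the second term.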

In the paper Dimensionality Reduction by Learning an Invariant Mapping, LeCun draws an analogy with springs: the positive-example term (the first half of the formula) acts like an attractive force, pulling matching samples closer together, while the negative-example term (the second half) acts like a repulsive force that only kicks in when negative examples are squeezed within the margin, pushing them farther apart.

Implementing Contrastive Loss in PyTorch

import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLoss(nn.Module):
    """
    Contrastive loss.
    Takes the embeddings of two samples and a target label: 1 if the
    samples are from the same class, 0 otherwise.
    """

    def __init__(self, margin):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        self.eps = 1e-9  # keeps sqrt() differentiable when the distance is 0

    def forward(self, output1, output2, target, size_average=True):
        distances = (output2 - output1).pow(2).sum(1)  # squared Euclidean distances
        losses = 0.5 * (target.float() * distances +
                        (1 - target).float() * F.relu(self.margin - (distances + self.eps).sqrt()).pow(2))
        return losses.mean() if size_average else losses.sum()

TODO

  1. Online Hard Contrastive Loss
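The TODO item above could be sketched roughly as follows. This is only one possible reading of "online hard" mining, not something from the original post: within each batch, every anchor is paired with its farthest same-class sample (hardest positive) and its closest different-class sample (hardest negative). The function name is mine, and it assumes every anchor has at least one positive and one negative in the batch.

```python
import torch
import torch.nn.functional as F

def online_hard_contrastive_loss(embeddings, labels, margin=1.0):
    """Hypothetical online hard-pair mining for contrastive loss.
    Assumes each anchor has >= 1 same-class and >= 1 different-class
    sample in the batch (e.g. via a class-balanced batch sampler)."""
    # Pairwise Euclidean distance matrix, shape (B, B).
    dist = torch.cdist(embeddings, embeddings, p=2)

    same = labels.unsqueeze(0) == labels.unsqueeze(1)  # (B, B) bool
    eye = torch.eye(len(labels), dtype=torch.bool, device=embeddings.device)
    pos_mask = same & ~eye  # positive pairs, excluding anchor-with-itself
    neg_mask = ~same        # negative pairs

    # Per anchor: hardest positive = farthest same-class sample.
    hardest_pos = dist.masked_fill(~pos_mask, float('-inf')).max(dim=1).values
    # Per anchor: hardest negative = closest different-class sample.
    hardest_neg = dist.masked_fill(~neg_mask, float('inf')).min(dim=1).values

    # Same two terms as the batch-wise loss, applied to the mined pairs only.
    return 0.5 * (hardest_pos.pow(2) + F.relu(margin - hardest_neg).pow(2)).mean()

emb = torch.tensor([[0.0, 0.0], [0.0, 1.0], [3.0, 0.0], [3.0, 1.0]])
labels = torch.tensor([0, 0, 1, 1])
loss = online_hard_contrastive_loss(emb, labels, margin=1.0)
print(loss.item())  # 0.5: every positive pair is at distance 1, all negatives are outside the margin
```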

References

Dimensionality Reduction by Learning an Invariant Mapping
sentence_transformers->ContrastiveLoss