def contrastive_loss()

in src/similarity/siamese.py [0:0]


def contrastive_loss(distance, labels):
    """Return the batch-mean contrastive penalty for a set of pair distances.

    Each element is classified by its label:

    * ``labels <= 0`` -> "same" pair; the penalty is ``distance ** 2``.
    * ``labels > 0``  -> "different" pair; the penalty is
      ``(labels - distance) ** 2``. The label value itself therefore acts as
      the target distance for that pair.

    NOTE(review): unlike the classic hinge form ``max(0, m - d) ** 2``, the
    "different" term is not clamped. It also penalises distances that exceed
    the label. Confirm that this is intended.

    Args:
        distance: tensor of pair distances.
        labels: tensor compared element-wise with ``distance``. It is
            presumably the same shape or broadcastable; verify against callers.

    Returns:
        A 0-dim tensor holding the mean of the per-element penalties.
    """
    # Float 0/1 mask; multiplying by it (rather than torch.where) keeps the
    # original dtype-promotion behaviour.
    different = (labels > 0.0).float()
    same = 1 - different

    # Squaring makes an abs() unnecessary: |x| ** 2 == x ** 2.
    same_term = same * distance ** 2
    different_term = different * (labels - distance) ** 2

    return (same_term + different_term).mean()