easy_rec/python/loss/circle_loss.py (41 lines of code) (raw):

# coding=utf-8 # Copyright (c) Alibaba, Inc. and its affiliates. import tensorflow as tf if tf.__version__ >= '2.0': tf = tf.compat.v1 def circle_loss(embeddings, labels, sessions=None, margin=0.25, gamma=32, embed_normed=False): """Paper: Circle Loss: A Unified Perspective of Pair Similarity Optimization. Link: http://arxiv.org/pdf/2002.10857.pdf Args: embeddings: A `Tensor` with shape [batch_size, embedding_size]. The embedding of each sample. labels: a `Tensor` with shape [batch_size]. e.g. click or not click in the session. sessions: a `Tensor` with shape [batch_size]. session ids of each sample. margin: the margin between positive similarity and negative similarity gamma: parameter of circle loss embed_normed: bool, whether input embeddings l2 normalized """ norm_embeddings = embeddings if embed_normed else tf.nn.l2_normalize( embeddings, axis=-1) pair_wise_cosine_matrix = tf.matmul( norm_embeddings, norm_embeddings, transpose_b=True) positive_mask = get_anchor_positive_triplet_mask(labels, sessions) negative_mask = 1 - positive_mask - tf.eye(tf.shape(labels)[0]) delta_p = 1 - margin delta_n = margin ap = tf.nn.relu(-tf.stop_gradient(pair_wise_cosine_matrix * positive_mask) + 1 + margin) an = tf.nn.relu( tf.stop_gradient(pair_wise_cosine_matrix * negative_mask) + margin) logit_p = -ap * (pair_wise_cosine_matrix - delta_p) * gamma * positive_mask - (1 - positive_mask) * 1e12 logit_n = an * (pair_wise_cosine_matrix - delta_n) * gamma * negative_mask - (1 - negative_mask) * 1e12 joint_neg_loss = tf.reduce_logsumexp(logit_n, axis=-1) joint_pos_loss = tf.reduce_logsumexp(logit_p, axis=-1) loss = tf.nn.softplus(joint_neg_loss + joint_pos_loss) return tf.reduce_mean(loss) def get_anchor_positive_triplet_mask(labels, sessions=None): """Return a 2D mask where mask[a, p] is 1.0 iff a and p are distinct and have same session and label. Args: labels: a `Tensor` with shape [batch_size] sessions: a `Tensor` with shape [batch_size] Returns: mask: tf.float32 `Tensor` with shape [batch_size, batch_size] """ # Check that i and j are distinct indices_equal = tf.cast(tf.eye(tf.shape(labels)[0]), tf.bool) indices_not_equal = tf.logical_not(indices_equal) # Check if labels[i] == labels[j] # Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1) labels_equal = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1)) # Check if labels[i] == labels[j] if sessions is None or sessions is labels: class_equal = labels_equal else: sessions_equal = tf.equal( tf.expand_dims(sessions, 0), tf.expand_dims(sessions, 1)) class_equal = tf.logical_and(sessions_equal, labels_equal) # Combine the three masks mask = tf.logical_and(indices_not_equal, class_equal) return tf.cast(mask, tf.float32)