def multisimilarity_loss()

in tensorflow_similarity/losses/multisim_loss.py


from typing import Any, Callable
import tensorflow as tf
from tensorflow_similarity.algebra import build_masks, masked_max, masked_min
from tensorflow_similarity.types import FloatTensor, IntTensor
from .utils import logsumexp  # masked log-sum-exp helper (module path assumed)


def multisimilarity_loss(labels: IntTensor,
                         embeddings: FloatTensor,
                         distance: Callable,
                         alpha: float = 2.0,
                         beta: float = 40.0,
                         epsilon: float = 0.2,
                         lmda: float = 1.0) -> Any:
    """Multi Similarity loss computations

    Args:
        labels: Labels associated with the embeddings.

        embeddings: Embedded examples.

        distance: Which distance function to use to compute the pairwise
        distances.

        alpha: The exponential weight for the positive pairs. Increasing alpha
        makes the logsumexp smooth max closer to the max positive pair
        distance, while decreasing it moves it toward the upper bound of
        max(P) + log(batch_size) / alpha.

        beta: The exponential weight for the negative pairs. Increasing beta
        makes the logsumexp smooth max closer to the hardest (closest)
        negative pair distance, while decreasing it moves it toward that
        value plus log(batch_size) / beta.

        epsilon: Used to remove easy positive and negative pairs. We only keep
        positives that are greater than the (smallest negative pair distance
        - epsilon) and we only keep negatives that are less than the
        (largest positive pair distance + epsilon).

        lmda: The distance threshold used to re-weight the pairs. Below this
        distance, negatives are up weighted and positives are down weighted.
        Similarly, above this distance negatives are down weighted and
        positives are up weighted.

    Returns:
        Loss: The loss value for the current batch.
    """
    # [Label]
    # ! Weirdness to be investigated
    # Do not remove this code: it is actually needed in specific situations.
    # Reshape label tensor to [batch_size, 1] if not already in that format.
    # labels = tf.reshape(labels, (labels.shape[0], 1))
    batch_size = tf.size(labels)

    # [distances]
    pairwise_distances = distance(embeddings)
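    # pairwise_distances[i, j] is the distance between embeddings i and j;
    # the result is a [batch_size, batch_size] FloatTensor.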

    # [masks]
    positive_mask, negative_mask = build_masks(labels, batch_size)
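    # positive_mask[i, j] is True when i and j share a label (self-pairs
    # excluded); negative_mask[i, j] is True when the labels differ.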

    # [pair mining using Similarity-P]
    # This is essentially hard mining of the negative and positive pairs.

    # Keep all positives > (min(neg_dist) - epsilon).
    neg_min, _ = masked_min(pairwise_distances, negative_mask)
    neg_min = tf.math.subtract(neg_min, epsilon)
    pos_sim_p_mask = tf.math.greater(pairwise_distances, neg_min)
    pos_sim_p_mask = tf.math.logical_and(pos_sim_p_mask, positive_mask)
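    # neg_min is the per-anchor distance to the closest (hardest) negative,
    # so the greater() comparison broadcasts across each row: a positive
    # survives only if it is within epsilon of, or farther than, that
    # hardest negative.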

    # Keep all negatives < (max(pos_dist) + epsilon).
    pos_max, _ = masked_max(pairwise_distances, positive_mask)
    pos_max = tf.math.add(pos_max, epsilon)
    neg_sim_p_mask = tf.math.less(pairwise_distances, pos_max)
    neg_sim_p_mask = tf.math.logical_and(neg_sim_p_mask, negative_mask)
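    # Symmetrically, pos_max is the per-anchor distance to the farthest
    # (hardest) positive, and a negative survives only if it is closer than
    # that hardest positive plus the epsilon slack.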

    # Mark all anchors that have at least one valid positive and one valid
    # negative pair.
    valid_anchors = tf.math.logical_and(
            tf.math.reduce_any(pos_sim_p_mask, axis=1),
            tf.math.reduce_any(neg_sim_p_mask, axis=1)
    )

    # Cast the masks to floats to support the multiplications below.
    valid_anchors = tf.cast(valid_anchors, dtype='float32')
    pos_sim_p_mask_f32 = tf.cast(pos_sim_p_mask, dtype='float32')
    neg_sim_p_mask_f32 = tf.cast(neg_sim_p_mask, dtype='float32')

    # [Weight the remaining pairs using Similarity-S and Similarity-N]
    shifted_distances = pairwise_distances - lmda
    pos_dists = alpha * shifted_distances
    neg_dists = -1 * beta * shifted_distances
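    # For example, with lmda=1.0, alpha=2.0, beta=40.0: a positive at
    # distance 1.5 contributes exp(2.0 * 0.5) to the positive logsumexp,
    # while a negative at distance 0.5 contributes exp(40.0 * 0.5) to the
    # negative one, so violating pairs dominate their respective sums.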

    # [compute loss]

    # Positive pairs with a shifted distance above 0 (i.e., a raw distance
    # above lmda) will be up weighted.
    p_loss = logsumexp(pos_dists, pos_sim_p_mask_f32)
    p_loss = p_loss / alpha

    # Negative pairs with a shifted distance below 0 (i.e., a raw distance
    # below lmda) will be up weighted.
    n_loss = logsumexp(neg_dists, neg_sim_p_mask_f32)
    n_loss = n_loss / beta
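    # Dividing by alpha (resp. beta) removes the temperature scaling, so each
    # logsumexp acts as a smooth, differentiable max over the anchor's mined
    # pairs while keeping the loss on the distance scale.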

    # Remove any anchors that have empty neg or pos pairs.
    # NOTE: reshape is required here because valid_anchors is [m] and
    #       p_loss + n_loss is [m, 1].
    multisim_loss = tf.math.multiply(
            p_loss + n_loss,
            tf.reshape(valid_anchors, (-1, 1))
    )

    return multisim_loss
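
A minimal usage sketch (not part of the source file). The helper
pairwise_cosine_distance is hypothetical and stands in for any distance
callable that maps a [batch_size, dim] embedding tensor to a
[batch_size, batch_size] distance matrix:

import tensorflow as tf

def pairwise_cosine_distance(embeddings):
    # Hypothetical stand-in for the distance argument: 1 - cosine similarity.
    emb = tf.math.l2_normalize(embeddings, axis=1)
    return 1.0 - tf.matmul(emb, emb, transpose_b=True)

labels = tf.constant([[0], [0], [1], [1]], dtype=tf.int32)  # [batch_size, 1]
embeddings = tf.random.normal((4, 16))                      # [batch_size, dim]

per_anchor_loss = multisimilarity_loss(labels, embeddings,
                                       pairwise_cosine_distance)
batch_loss = tf.math.reduce_mean(per_anchor_loss)  # reduce to a scalar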