in tensorflow_similarity/losses/circle_loss.py [0:0]
def circle_loss(labels: IntTensor,
                embeddings: FloatTensor,
                distance: Callable,
                gamma: float = 80,
                margin: float = 0.4) -> Any:
    """Circle loss computations.

    The original paper used cosine similarity while this loss has been
    modified to work with cosine distance (d = 1 - s), so the optimum and
    margin terms below are the similarity-space constants of section 3.2 of
    the paper rewritten in distance space.

    Args:
        labels: Labels associated with the embeddings.
        embeddings: Embeddings as inferred by the model.
        distance: Which distance function to use to compute the pairwise
          distances between embeddings. The distance is expected to be
          between [0, 2]. Defaults to 'cosine'.
        gamma: Scaling term. Defaults to 80. Note: Large values cause the
          LogSumExp to return the Max pair and reduces the weighted mixing
          of all pairs. Should be hypertuned.
        margin: Used to weight the distance. Below this distance, negatives
          are up weighted and positives are down weighted. Similarly, above
          this distance negatives are down weighted and positive are up
          weighted. Defaults to 0.4.

    Returns:
        Loss: The per-anchor loss values for the current batch, shape
        [batch_size, 1]. Anchors lacking a positive or a negative pair
        contribute zero.
    """
    # Switched from what's in the paper to work with distance instead of
    # similarity: with s = 1 - d, O_p = 1 + m and O_n = -m become the
    # constants below.
    optim_pos = margin
    optim_neg = 1 + margin
    delta_pos = margin
    delta_neg = 1 - margin

    batch_size = tf.size(labels)

    # [distances] pairwise distance matrix, shape [batch_size, batch_size].
    pairwise_distances = distance(embeddings)

    # [masks] -> filter to keep only the relevant values - zero the rest.
    positive_mask, negative_mask = build_masks(labels, batch_size)

    # An anchor is only usable if it has at least one positive AND one
    # negative pair in the batch.
    valid_anchors = tf.math.logical_and(
        tf.math.reduce_any(positive_mask, axis=1),
        tf.math.reduce_any(negative_mask, axis=1)
    )

    # Cast masks to the distances' dtype to support multiply. Using the
    # distances' dtype (rather than hard-coding float32) keeps the loss
    # working under mixed precision / float64 embeddings, where a float32
    # mask would trigger a dtype-mismatch error.
    dtype = pairwise_distances.dtype
    valid_anchors = tf.cast(valid_anchors, dtype=dtype)
    positive_mask = tf.cast(positive_mask, dtype=dtype)
    negative_mask = tf.cast(negative_mask, dtype=dtype)

    # [weights] from (5) in 3.1 using optim values of 3.2.
    # Implementation note: we do all the computation on the full pairwise
    # matrix and filter at the end to keep only relevant values.
    # positive weights: alpha_p = [O_p - s_p]+ == [m + d]+ in distance space
    pos_weights = optim_pos + pairwise_distances  # (5) in 3.1
    pos_weights = pos_weights * positive_mask     # filter
    pos_weights = tf.maximum(pos_weights, 0.0)    # clip at zero

    # negative weights: alpha_n = [s_n - O_n]+ == [1 + m - d]+
    neg_weights = optim_neg - pairwise_distances  # (5) in 3.1
    neg_weights = neg_weights * negative_mask     # filter
    neg_weights = tf.maximum(neg_weights, 0.0)    # clip at zero

    # Subtract the between and within class margins.
    pos_dists = delta_pos - pairwise_distances
    neg_dists = delta_neg - pairwise_distances

    # Applying weights as in (4) in 3.1. The commented "/ 2" would correct
    # for each pair being counted twice in the full pairwise matrix.
    pos_wdists = (-1 * gamma * pos_weights * pos_dists)  # / 2
    neg_wdists = (gamma * neg_weights * neg_dists)       # / 2

    p_loss = logsumexp(pos_wdists, positive_mask)
    n_loss = logsumexp(neg_wdists, negative_mask)

    # Remove any anchors that have empty neg or pos pairs.
    # NOTE: reshape is required here because valid_anchors is [m] and
    # p_loss + n_loss is [m, 1].
    circle_loss = tf.math.multiply(
        p_loss + n_loss,
        tf.reshape(valid_anchors, (-1, 1))
    )
    return circle_loss