in tensorflow_similarity/training_metrics/distance_metrics.py [0:0]
def update_state(self, labels, embeddings, sample_weight):
    """Compute the aggregated mined distance for a batch and store it.

    Pairwise distances between `embeddings` are computed with
    `self.distance`, masked to positive or negative pairs depending on
    `self.anchor`, reduced per row with `masked_max` / `masked_min`
    according to the configured mining strategy, and finally reduced
    with the op selected by `self.aggregate`. The result is assigned to
    `self.aggregated_distances`; nothing is returned.

    Args:
        labels: Batch labels, forwarded to `build_masks` together with
            `tf.size(labels)`.
            NOTE(review): presumably one label per embedding -- confirm.
        embeddings: Batch embeddings, forwarded to `self.distance`.
        sample_weight: Accepted but not used by this method.

    Raises:
        ValueError: If `self.aggregate` is not one of 'mean', 'avg',
            'max', 'min' or 'sum'.
    """
    # Fail fast on a bad configuration: without this check an unknown
    # aggregate would only surface as an UnboundLocalError at the final
    # assignment, after all the distance/mask work had already run.
    aggregate = self.aggregate
    supported_aggregates = ("mean", "avg", "max", "min", "sum")
    if aggregate not in supported_aggregates:
        raise ValueError(
            f"Unsupported aggregate {aggregate!r}; "
            f"expected one of {supported_aggregates}."
        )

    # [distances]
    pairwise_distances = self.distance(embeddings)

    # [mask]
    batch_size = tf.size(labels)
    positive_mask, negative_mask = build_masks(labels, batch_size)

    if self.anchor == "positive":
        # "hard" positives take the masked max; any other strategy
        # value takes the masked min.
        if self.positive_mining_strategy == "hard":
            distances, _ = masked_max(pairwise_distances, positive_mask)
        else:
            distances, _ = masked_min(pairwise_distances, positive_mask)
    else:
        # Any anchor other than "positive" is handled as the negative
        # case. "hard" negatives take the masked min; any other
        # strategy value takes the masked max.
        if self.negative_mining_strategy == "hard":
            distances, _ = masked_min(pairwise_distances, negative_mask)
        else:
            distances, _ = masked_max(pairwise_distances, negative_mask)

    # reduce
    if aggregate in ("mean", "avg"):
        aggregated_distances = tf.reduce_mean(distances)
    elif aggregate == "max":
        aggregated_distances = tf.reduce_max(distances)
    elif aggregate == "min":
        aggregated_distances = tf.reduce_min(distances)
    else:
        # Only "sum" remains after the validation above.
        aggregated_distances = tf.reduce_sum(distances)

    self.aggregated_distances = aggregated_distances