def update_state()

in tensorflow_similarity/training_metrics/distance_metrics.py [0:0]


    def update_state(self, labels, embeddings, sample_weight):
        """Compute and store the aggregated mined distance for a batch.

        The pairwise distances produced by ``self.distance(embeddings)`` are
        filtered through the positive or negative mask built from ``labels``
        (depending on ``self.anchor``), mined with ``masked_max`` or
        ``masked_min`` according to the configured mining strategy, and then
        reduced to a single value following ``self.aggregate``. The result is
        stored in ``self.aggregated_distances``.

        Args:
            labels: Labels of the batch; used to build the positive and
                negative masks, and its size is used as the batch size.
            embeddings: Embeddings passed to ``self.distance``.
            sample_weight: Not used by this method.

        Raises:
            ValueError: If ``self.aggregate`` is not one of 'mean', 'avg',
                'max', 'min' or 'sum'.
        """

        # [distances]
        pairwise_distances = self.distance(embeddings)

        # [mask]
        batch_size = tf.size(labels)
        positive_mask, negative_mask = build_masks(labels, batch_size)

        # [mining]
        # Any anchor other than "positive" is treated as a negative anchor.
        if self.anchor == "positive":
            # "hard" positives use masked_max; any other strategy masked_min.
            if self.positive_mining_strategy == "hard":
                distances, _ = masked_max(pairwise_distances, positive_mask)
            else:
                distances, _ = masked_min(pairwise_distances, positive_mask)
        else:
            # "hard" negatives use masked_min; any other strategy masked_max.
            if self.negative_mining_strategy == "hard":
                distances, _ = masked_min(pairwise_distances, negative_mask)
            else:
                distances, _ = masked_max(pairwise_distances, negative_mask)

        # [reduce]
        if self.aggregate in ("mean", "avg"):
            aggregated_distances = tf.reduce_mean(distances)
        elif self.aggregate == "max":
            aggregated_distances = tf.reduce_max(distances)
        elif self.aggregate == "min":
            aggregated_distances = tf.reduce_min(distances)
        elif self.aggregate == "sum":
            aggregated_distances = tf.reduce_sum(distances)
        else:
            # Without this branch an unsupported value would surface as an
            # opaque UnboundLocalError on the assignment below.
            raise ValueError(
                f"Unsupported aggregate {self.aggregate!r}; expected one of "
                "'mean', 'avg', 'max', 'min', 'sum'."
            )

        self.aggregated_distances = aggregated_distances