def on_epoch_end()

in tensorflow_similarity/callbacks.py [0:0]


    def on_epoch_end(self, epoch: int, logs: dict = None):
        """Compute the eval metrics at the end of each epoch.

        NOTE: This method resets the index and batch-adds the target
        embeddings to the index using the new embeddings generated by the
        current version of the model. The index is reset again before
        returning, even if indexing or metric computation raises, so the
        temporary evaluation index cannot be used by accident after the
        callback.

        Args:
            epoch: Index of the epoch that just finished. Used as the
                ``step`` for the TensorBoard scalars when ``self.tb_writer``
                is set.
            logs: Dict that receives the computed metrics under the keys
                ``<metric>_known_classes`` and ``<metric>_unknown_classes``.
                If None, a fresh local dict is used.
        """
        if logs is None:
            logs = {}

        metric_strs = []

        # Start from an empty index so only embeddings produced by the
        # current model weights are searched.
        self.model.reset_index()
        try:
            # rebuild the index
            self.model.index(self.targets, self.target_labels, verbose=0)

            # Compute both result sets before logging anything, so a failure
            # in either computation leaves ``logs`` untouched.
            known_results = _compute_classification_metrics(
                queries=self.queries_known,
                query_labels=self.query_labels_known,
                model=self.model,
                evaluator=self.evaluator,
                metrics=self.metrics,
                k=self.k,
                matcher=self.matcher,
                distance_thresholds=self.distance_thresholds,
            )

            unknown_results = _compute_classification_metrics(
                queries=self.queries_unknown,
                query_labels=self.query_labels_unknown,
                model=self.model,
                evaluator=self.evaluator,
                metrics=self.metrics,
                k=self.k,
                matcher=self.matcher,
                distance_thresholds=self.distance_thresholds,
            )

            results_by_suffix = (
                ("known_classes", known_results),
                ("unknown_classes", unknown_results),
            )
            for suffix, results in results_by_suffix:
                for metric_name, vals in results.items():
                    # Only the first value of each metric is reported.
                    # NOTE(review): presumably this corresponds to the first
                    # entry of ``distance_thresholds`` — confirm.
                    float_val = vals[0]
                    full_metric_name = f"{metric_name}_{suffix}"
                    logs[full_metric_name] = float_val
                    metric_strs.append(f"{full_metric_name}: {float_val:0.4f}")
                    if self.tb_writer:
                        with self.tb_writer.as_default():
                            tf.summary.scalar(
                                full_metric_name, float_val, step=epoch
                            )
        finally:
            # Reset the index to prevent users from accidentally using it
            # after the callback. This also runs when an exception is raised.
            self.model.reset_index()

        print(" - ".join(metric_strs))