in tensorflow_similarity/callbacks.py [0:0]
def on_epoch_end(self, epoch: int, logs: dict = None):
    """Computes the eval metrics at the end of each epoch.

    NOTE: This method resets the index and batch adds the target embeddings
    to the index using the new embeddings generated by the current version
    of the model.

    Classification metrics are computed separately for the known-class
    queries (``self.queries_known`` / ``self.query_labels_known``) and the
    unknown-class queries (``self.queries_unknown`` /
    ``self.query_labels_unknown``). Each metric is:

    * stored in ``logs`` under ``"<metric>_known_classes"`` or
      ``"<metric>_unknown_classes"``,
    * written as a scalar summary when ``self.tb_writer`` is set, and
    * included in a single summary line printed at the end.

    The index is reset again before returning, even if indexing or metric
    computation raises. This way the callback's evaluation index cannot be
    used by accident after the callback runs.

    Args:
        epoch: Index of the epoch that just finished. It is used as the
            ``step`` of the scalar summaries.
        logs: Per-epoch logs dict. It is updated in place with the computed
            metrics. If None, a local dict is used and then discarded.
    """
    if logs is None:
        logs = {}

    mstr = []
    try:
        # Rebuild the index from scratch with the current model weights.
        self.model.reset_index()
        self.model.index(self.targets, self.target_labels, verbose=0)

        # Compute both splits before logging anything, so a failure on
        # either split leaves `logs` and the summary writer untouched.
        splits = (
            ("known_classes", self.queries_known, self.query_labels_known),
            ("unknown_classes", self.queries_unknown, self.query_labels_unknown),
        )
        split_results = []
        for suffix, queries, query_labels in splits:
            results = _compute_classification_metrics(
                queries=queries,
                query_labels=query_labels,
                model=self.model,
                evaluator=self.evaluator,
                metrics=self.metrics,
                k=self.k,
                matcher=self.matcher,
                distance_thresholds=self.distance_thresholds,
            )
            split_results.append((suffix, results))

        for suffix, results in split_results:
            for metric_name, vals in results.items():
                # Only the first value of each metric is reported.
                # Presumably this corresponds to the first distance
                # threshold; TODO confirm against
                # _compute_classification_metrics.
                float_val = vals[0]
                full_metric_name = f"{metric_name}_{suffix}"
                logs[full_metric_name] = float_val
                mstr.append(f"{full_metric_name}: {float_val:0.4f}")
                if self.tb_writer:
                    with self.tb_writer.as_default():
                        tf.summary.scalar(full_metric_name, float_val, step=epoch)
    finally:
        # Reset the index to prevent users from accidentally using it after
        # the callback, including when evaluation failed part-way through.
        self.model.reset_index()

    print(" - ".join(mstr))