decisionai_plugin/common/util/metric.py (42 lines of code) (raw):

import tensorflow as tf
import json
import logging

from .meta import update_state

logger = logging.getLogger(__name__)


class Metric:
    """Snapshot of training progress at the end of one epoch.

    Args:
        epochs: Total number of epochs configured for the training run.
        epoch_th: Epoch index supplied by the caller (``MetricCollector``
            forwards the Keras epoch index unchanged).
        loss: Training loss reported for that epoch.
        valid_loss: Validation loss for that epoch, or ``None`` when no
            validation loss was reported.
    """

    def __init__(self, epochs, epoch_th, loss, valid_loss):
        self.__epochs = epochs
        self.__current_epoch = epoch_th
        self.__loss = loss
        self.__valid_loss = valid_loss

    @property
    def epochs(self):
        """Total number of epochs in the run."""
        return self.__epochs

    @property
    def epoch(self):
        """Epoch index this snapshot describes."""
        return self.__current_epoch

    @property
    def loss(self):
        """Training loss for this epoch."""
        return self.__loss

    @property
    def valid_loss(self):
        """Validation loss for this epoch, or ``None`` if unavailable."""
        return self.__valid_loss

    def __repr__(self):
        return '%s(epochs=%r, epoch_th=%r, loss=%r, valid_loss=%r)' % (
            type(self).__name__,
            self.__epochs,
            self.__current_epoch,
            self.__loss,
            self.__valid_loss,
        )


class MetricSender:
    """Publishes per-epoch training metrics through ``update_state``.

    Args:
        config: Configuration object forwarded to ``update_state``.
        subscription: Subscription identifier forwarded to ``update_state``.
        model_id: Identifier of the model whose state is updated.
    """

    def __init__(self, config, subscription, model_id):
        self.__config = config
        self.__subscription = subscription
        self.__model_id = model_id

    def send(self, metric: Metric):
        """Report *metric* to ``update_state`` as a one-line progress message.

        The message has the form
        ``'epoch: <epoch>/<epochs>, loss: <loss>, validate loss: <valid_loss>'``.
        The two ``None`` positional arguments are passed to ``update_state``
        exactly as before; NOTE(review): their meaning is defined in ``.meta``
        (presumably "leave these fields unchanged") -- confirm there.
        """
        # '%s' applies str() to each value, which is exactly what the previous
        # concatenation did. An f-string or str.format would call format()
        # instead, and that can render numpy float scalars differently.
        txt = 'epoch: %s/%s, loss: %s, validate loss: %s' % (
            metric.epoch,
            metric.epochs,
            metric.loss,
            metric.valid_loss,
        )
        update_state(self.__config, self.__subscription, self.__model_id, None, None, txt)


class MetricCollector(tf.keras.callbacks.Callback):
    """Keras callback that forwards each epoch's losses to a ``MetricSender``.

    Args:
        epochs: Total number of epochs the model is trained for; copied into
            every ``Metric`` so receivers can report progress as ``epoch/epochs``.
        metric_sender: Destination for the per-epoch ``Metric`` snapshots.
    """

    def __init__(self, epochs, metric_sender: MetricSender):
        super().__init__()
        self.__epochs = epochs
        self.__sender = metric_sender

    def on_epoch_end(self, epoch, logs=None):
        """Build a ``Metric`` from the epoch's Keras *logs* and send it.

        Args:
            epoch: Epoch index passed by Keras, forwarded unchanged.
                NOTE(review): Keras indexes epochs from ``initial_epoch``
                (0 by default), so the sent text may read ``0/N`` for the
                first epoch -- confirm that is what consumers expect.
            logs: Per-epoch metric dict from Keras, or ``None``. ``'loss'``
                and ``'val_loss'`` are read; a missing entry is reported
                as ``None``.
        """
        # The callback signature allows logs=None. Treat that as "no metrics"
        # instead of raising TypeError on the membership test or subscript.
        logs = logs or {}
        metric = Metric(
            epochs=self.__epochs,
            epoch_th=epoch,
            loss=logs.get('loss'),
            valid_loss=logs.get('val_loss'),
        )
        self.__sender.send(metric)