cli/jobs/pipelines/tensorflow-image-segmentation/src/tf_helper/profiling.py (65 lines of code) (raw):
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
This module provides helper code for profiling TensorFlow training.
"""
import time
import logging
import mlflow
from tensorflow import keras
import tensorflow
from common.profiling import LogTimeOfIterator
class CustomCallbacks(keras.callbacks.Callback):
    """Keras callback that times training phases and reports metrics to MLflow.

    To use during model.fit(): pass an instance in the ``callbacks`` list.
    At the end of every epoch the collected metrics are written to the module
    logger and, when ``enabled`` is true, sent with ``mlflow.log_metrics``.

    Metrics kept in ``self.metrics`` (times are ``time.time()`` differences,
    i.e. seconds):

    * ``epoch_init_time``: time between the end of the previous epoch (or the
      construction of this callback, for the first epoch) and the start of
      the current epoch.
    * ``epoch_eval_time``: duration of the last test pass
      (``on_test_begin`` to ``on_test_end``).
    * ``epoch_train_time``: epoch duration minus ``epoch_eval_time``.
    * ``epoch_train_<key>`` / ``epoch_valid_<key>``: values copied from the
      ``logs`` dict given to ``on_epoch_end``; keys prefixed with ``val_``
      are reported as ``epoch_valid_<key without prefix>``.

    Args:
        enabled: when false, metrics are still computed and logged through
            ``logging`` but are not sent to MLflow.
    """

    def __init__(self, enabled=True):
        # Let the Keras base class set up its own state before adding ours.
        super().__init__()
        self.logger = logging.getLogger(__name__)
        self.metrics = {}
        self.train_start = None
        self.epoch_start = None
        self.epoch_end = time.time()  # required for 1st epoch_init_time
        self.test_start = None
        self.enabled = enabled

    def _flush(self):
        """Log the current metrics and, if enabled, send them to MLflow."""
        self.logger.info("MLFLOW: metrics=%s", self.metrics)
        if self.enabled:
            mlflow.log_metrics(self.metrics)

    def on_epoch_begin(self, epoch, logs=None):
        """Record the inter-epoch gap and start the epoch timer."""
        logs = logs or {}  # the signature allows None
        self.metrics["epoch_init_time"] = time.time() - self.epoch_end
        # Forget the eval time of an earlier epoch: if no test pass runs
        # during this epoch, a stale value must not be subtracted from (or
        # reported for) this epoch.
        self.metrics.pop("epoch_eval_time", None)
        self.logger.info(
            "Start epoch %s of training; got log keys: %s", epoch, list(logs)
        )
        self.epoch_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        """Compute epoch timings, collect the epoch logs and flush metrics."""
        logs = logs or {}  # the signature allows None
        self.logger.info(
            "End epoch %s of training; got log keys: %s", epoch, list(logs)
        )
        epoch_time = time.time() - self.epoch_start
        # No test pass during this epoch means no eval time to subtract;
        # indexing the key directly would raise KeyError in that case.
        eval_time = self.metrics.get("epoch_eval_time", 0.0)
        self.metrics["epoch_train_time"] = epoch_time - eval_time

        # add epoch metrics
        for key, value in logs.items():
            # align with our naming conventions
            if key.startswith("val_"):
                self.metrics[f"epoch_valid_{key[4:]}"] = value
            else:
                self.metrics[f"epoch_train_{key}"] = value

        self.epoch_end = time.time()
        self._flush()

    def on_test_begin(self, logs=None):
        """Start the test timer."""
        logs = logs or {}  # the signature allows None
        self.logger.info("Start testing; got log keys: %s", list(logs))
        self.test_start = time.time()

    def on_test_end(self, logs=None):
        """Store the duration of the test pass as ``epoch_eval_time``."""
        logs = logs or {}  # the signature allows None
        self.logger.info("Stop testing; got log keys: %s", list(logs))
        self.metrics["epoch_eval_time"] = time.time() - self.test_start

    def on_train_begin(self, logs=None):
        """Log the start of training."""
        logs = logs or {}  # the signature allows None
        self.logger.info("Starting training; got log keys: %s", list(logs))

    def on_train_end(self, logs=None):
        """Log the end of training."""
        logs = logs or {}  # the signature allows None
        self.logger.info("Stop training; got log keys: %s", list(logs))
class LogTimeOfTensorFlowIterator(LogTimeOfIterator):
    """Wrapper around an existing Iterator that logs metrics for each
    next() call, with a helper to expose it as a tensorflow dataset."""

    def as_tf_dataset(self):
        """Return this wrapper as a tensorflow dataset.

        When the wrapper is not enabled, the wrapped sequence is handed back
        as-is instead.
        """
        if not self.enabled:
            return self.wrapped_sequence

        # works only if wrapped_sequence is already a tf.data.Dataset
        signature = self.wrapped_sequence.element_spec

        def _iterate_self():
            # from_generator expects a callable; it hands back this wrapper
            return self

        return tensorflow.data.Dataset.from_generator(
            _iterate_self,
            output_signature=signature,
        )