def get_tensorflow_exception_classes()

in src/sagemaker_training/process.py [0:0]


def get_tensorflow_exception_classes():
    """Return the names of TensorFlow exception classes to watch for.

    TensorFlow exception classes are reused by XLA. XLA is present in SageMaker Training Compiler
    enabled TensorFlow and PyTorch DLCs.

    Returns:
        list: Exception class *names* (strings), not class objects. When TensorFlow can be
        imported, this holds the name of every exception class found on
        ``tensorflow.python.framework.errors_impl`` (sorted by name, as ``getmembers`` returns
        them), followed by ``"XlaRuntimeError"``. When TensorFlow is not installed, it is
        ``[DEFAULT_ERROR_CLASS]`` (a module-level default defined elsewhere in this file).
    """
    exception_classes = []
    try:
        from tensorflow.python.framework import errors_impl
    except ImportError:
        logger.info("Exceptions not imported for SageMaker TF as Tensorflow is not installed.")
    else:
        # List of exceptions from TensorFlow that sagemaker-training-toolkit should catch and log.
        # `isclass` alone would also pick up any non-exception helper classes the module defines
        # or imports, so keep only genuine exception types.
        exception_classes += [
            name
            for name, obj in getmembers(errors_impl, isclass)
            if issubclass(obj, BaseException)
        ]
        # XlaRuntimeError is added as a str (process.watch can handle str) because there is no
        # proper import of module tensorflow/compiler/xla/python/xla_client.py available.
        exception_classes += ["XlaRuntimeError"]

    if not exception_classes:
        exception_classes = [DEFAULT_ERROR_CLASS]
    return exception_classes