def get_modelparallel_exception_classes()

in src/sagemaker_training/mpi.py [0:0]


def get_modelparallel_exception_classes():
    """Gather the names of the SMMP exception classes to catch and log.

    Two optional modules are probed in order:
    ``smdistributed.modelparallel.backend.exceptions`` and
    ``smdistributed.modelparallel.torch.exceptions``. For each one that can
    be imported, the name of every attribute listed by ``dir()`` whose value
    is a class is collected. A module that cannot be imported is skipped
    with an info-level log message.

    Returns:
        list: The collected attribute names (strings), backend names first,
        then torch names. If nothing was collected, a single-element list
        holding ``DEFAULT_ERROR_CLASS`` is returned instead.
    """

    def _class_names(module):
        # Names of all attributes of `module` that are classes, in dir() order.
        found = []
        for attr_name in dir(module):
            if isclass(getattr(module, attr_name)):
                found.append(attr_name)
        return found

    names = []

    # Exceptions SMMP wants the training toolkit to catch and log.
    try:
        from smdistributed.modelparallel.backend import exceptions as backend_exceptions

        names.extend(_class_names(backend_exceptions))
    except ImportError:
        logger.info("No exception classes found in smdistributed.modelparallel.backend")

    # Torch-specific exceptions SMMP wants the training toolkit to catch and log.
    try:
        from smdistributed.modelparallel.torch import exceptions as torch_exceptions

        names.extend(_class_names(torch_exceptions))
    except ImportError:
        logger.info("No torch exception classes found in smdistributed.modelparallel.torch")

    # Fall back to the default error class when neither module contributed.
    return names if names else [DEFAULT_ERROR_CLASS]