in src/sagemaker_training/mpi.py [0:0]
def _exception_class_names(module):
    """Return the names of the exception classes exposed as attributes of ``module``.

    Args:
        module: Any object supporting ``dir()`` / ``getattr()``, normally a module.

    Returns:
        list[str]: Attribute names, in ``dir()`` order, whose value is a class
        derived from ``BaseException``.
    """
    names = []
    for attr_name in dir(module):
        attr = getattr(module, attr_name)
        # Plain ``isclass`` would also accept helper classes that merely happen to
        # be imported into the module; only BaseException subclasses can be raised.
        if isclass(attr) and issubclass(attr, BaseException):
            names.append(attr_name)
    return names


def get_modelparallel_exception_classes():
    """Collect the names of the smdistributed.modelparallel exception classes.

    Inspects ``smdistributed.modelparallel.backend.exceptions`` and
    ``smdistributed.modelparallel.torch.exceptions``. A module that cannot be
    imported is skipped and an info-level message is logged instead.

    Returns:
        list[str]: Exception class names (strings, not class objects) without
        duplicates, in first-seen order. When no name was collected the result
        is ``[DEFAULT_ERROR_CLASS]``.
    """
    exception_classes = []

    # The scans stay inside the ``try`` bodies so that an ImportError raised while
    # inspecting a module is handled exactly like a failed import.
    try:
        from smdistributed.modelparallel.backend import exceptions

        # list of exceptions SMMP wants training toolkit to catch and log
        exception_classes += _exception_class_names(exceptions)
    except ImportError:
        logger.info("No exception classes found in smdistributed.modelparallel.backend")

    try:
        from smdistributed.modelparallel.torch import exceptions as torch_exceptions

        # list of torch exceptions SMMP wants training toolkit to catch and log
        exception_classes += _exception_class_names(torch_exceptions)
    except ImportError:
        logger.info("No torch exception classes found in smdistributed.modelparallel.torch")

    # A class exposed by both modules would otherwise be listed twice;
    # dict.fromkeys drops repeats while keeping first-seen order.
    exception_classes = list(dict.fromkeys(exception_classes))

    if not exception_classes:
        exception_classes = [DEFAULT_ERROR_CLASS]
    return exception_classes