in src/sagemaker_training/trainer.py [0:0]
def train():
    """The main function responsible for running training in the container.

    Steps:
      1. Build the training ``environment.Environment``.
      2. Start syncing intermediate output to ``env.sagemaker_s3_output()``.
         - The region comes from ``AWS_REGION``, falling back to the variable
           named by ``params.REGION_NAME_ENV``.
         - An optional S3 endpoint override is read from the variable named by
           ``params.S3_ENDPOINT_URL``.
      3. Run the training code:
         - If ``env.framework_module`` is set (``"<module>:<callable>"``), import
           that module and call the named callable.
         - Otherwise run the user entry point. The MPI runner is used when MPI is
           enabled and the current instance group is one of
           ``env.distribution_instance_groups``; the plain process runner is used
           otherwise.
      4. Write a success or failure file describing the outcome.

    In every case the ``finally`` clause waits for the intermediate-output sync
    (when one was started) and then hands the computed exit code to
    ``_exit_processes``.

    Exit code handed to ``_exit_processes``:
      * ``SUCCESS_CODE`` when training completes without raising.
      * ``DEFAULT_FAILURE_CODE`` for ``errors.ClientError``.
      * For any other exception, the exception's ``errno`` (or
        ``DEFAULT_FAILURE_CODE`` when it has none / it is None), normalised
        through ``_get_valid_failure_exit_code``.
      * ``DEFAULT_FAILURE_CODE`` if failure reporting itself raises.
    """
    intermediate_sync = None
    exit_code = SUCCESS_CODE
    try:
        env = environment.Environment()

        region = os.environ.get("AWS_REGION", os.environ.get(params.REGION_NAME_ENV))
        s3_endpoint_url = os.environ.get(params.S3_ENDPOINT_URL, None)
        intermediate_sync = intermediate_output.start_sync(
            env.sagemaker_s3_output(), region, endpoint_url=s3_endpoint_url
        )

        if env.framework_module:
            framework_name, entry_point_name = env.framework_module.split(":")
            framework = importlib.import_module(framework_name)
            # the logger is configured after importing the framework library, allowing
            # the framework to configure logging at import time.
            logging_config.configure_logger(env.log_level)
            logger.info("Imported framework %s", framework_name)
            entrypoint = getattr(framework, entry_point_name)
            entrypoint()
        else:
            logging_config.configure_logger(env.log_level)
            mpi_enabled = env.additional_framework_parameters.get(params.MPI_ENABLED)
            runner_type = (
                runner.RunnerType.MPI
                if mpi_enabled and (env.current_instance_group in env.distribution_instance_groups)
                else runner.RunnerType.Process
            )
            entry_point.run(
                env.module_dir,
                env.user_entry_point,
                env.to_cmd_args(),
                env.to_env_vars(),
                runner_type=runner_type,
            )

        logger.info("Reporting training SUCCESS")
        files.write_success_file()
    except errors.ClientError as e:
        # Record the failure code first so that, if reporting below raises, the
        # ``finally`` clause still exits with a failure instead of SUCCESS_CODE.
        exit_code = DEFAULT_FAILURE_CODE
        failure_msg = str(e)
        files.write_failure_file(failure_msg)
        logger.error("Reporting training FAILURE")
        logger.error(failure_msg)
        # No explicit join here: the ``finally`` clause always waits for the sync.
    except Exception as e:  # pylint: disable=broad-except
        # Same reasoning as above: never fall through to SUCCESS_CODE on failure.
        exit_code = DEFAULT_FAILURE_CODE
        # Format the traceback once; it drives both classification and the message.
        trace = traceback.format_exc()
        if any(path in trace for path in SM_TRAINING_COMPILER_PATHS):
            failure_msg = "SMTrainingCompiler Error: \n%s\n%s" % (trace, str(e))
        else:
            failure_msg = "Framework Error: \n%s\n%s" % (trace, str(e))
        files.write_failure_file(failure_msg)
        logger.error("Reporting training FAILURE")
        logger.error(failure_msg)
        # getattr's default only covers a *missing* attribute. OSError (and its
        # subclasses) always define ``errno`` but leave it as None when raised
        # without an error number, so fall back explicitly in that case too.
        error_number = getattr(e, "errno", None)
        if error_number is None:
            error_number = DEFAULT_FAILURE_CODE
        exit_code = _get_valid_failure_exit_code(error_number)
    finally:
        # Let the intermediate-output sync finish before handing off the exit code.
        if intermediate_sync:
            intermediate_sync.join()
        _exit_processes(exit_code)