def train()

in src/sagemaker_pytorch_container/training.py


def train(training_environment):
    """Run PyTorch training on a user supplied module.

    The user supplied module is run in either a local or distributed SageMaker
    environment.

    The user supplied module and its dependencies are downloaded from S3.
    Training is invoked by running the user supplied entry point (script or module).
    If the environment contains multiple hosts, then a distributed learning
    task is started.

    Args:
        training_environment: training environment object containing environment
            variables, training arguments and hyperparameters.
    """
    _sm_studio_local_mode = os.environ.get("SM_STUDIO_LOCAL_MODE", "False").lower() == "true"

    if not _sm_studio_local_mode:
        # Block until all host DNS lookups succeed. Relies on retrying dns_lookup.
        logger.info('Block until all host DNS lookups succeed.')
        for host in training_environment.hosts:
            _dns_lookup(host)
    else:
        logger.info('Bypass DNS check in case of Studio Local Mode execution.')

    _set_nccl_environment(training_environment.network_interface_name)

    _set_distributed_environment(training_environment)

    mpi_enabled = training_environment.additional_framework_parameters.get(LAUNCH_MPI_ENV_NAME)

    pytorch_ddp_enabled = training_environment.additional_framework_parameters.get(
        LAUNCH_PYTORCH_DDP_ENV_NAME, False
    )

    smdataparallel_enabled = training_environment.additional_framework_parameters.get(
        LAUNCH_SMDATAPARALLEL_ENV_NAME, False
    )

    pytorch_xla_enabled = training_environment.additional_framework_parameters.get(
        LAUNCH_PYTORCH_XLA_ENV_NAME, False
    )

    torch_distributed_enabled = training_environment.additional_framework_parameters.get(
        LAUNCH_TORCH_DISTRIBUTED_ENV_NAME, False
    )
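
    # These sagemaker_* launch flags are typically written by the SageMaker Python SDK
    # from the estimator's `distribution` configuration; the chain below selects a
    # single runner in priority order: MPI, PT DDP, torch.distributed,
    # SMDataParallel, PT-XLA.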
    # Default scenario: run the user script with a plain process runner.
    runner_type = runner.ProcessRunnerType

    if training_environment.current_instance_group in training_environment.distribution_instance_groups:
        if mpi_enabled:
            runner_type = runner.MPIRunnerType
        elif pytorch_ddp_enabled:
            runner_type = runner.SMDataParallelRunnerType
            logger.info('Invoking SMDataParallel for native PT DDP job')
        elif torch_distributed_enabled:
            runner_type = runner.TorchDistributedRunnerType
            logger.info('Invoking TorchDistributed...')
        elif smdataparallel_enabled:
            runner_type = runner.SMDataParallelRunnerType
            logger.info('Invoking SMDataParallel')
        elif pytorch_xla_enabled:
            runner_type = runner.PyTorchXLARunnerType
            logger.info('Invoking PT-XLA Runner')
    logger.info('Invoking user training script.')

    # capture_error defaults to True; it is disabled when the toolkit's
    # native launcher is enabled (see the framework parameter check below).
    capture_error = True
    if training_environment.additional_framework_parameters.get("sagemaker_toolkit_native_launcher_enabled"):
        capture_error = False
        logger.info(f'capture_error is {capture_error}. Default is True')

    _set_torch_version_environment()
    try:
        entry_point.run(uri=training_environment.module_dir,
                        user_entry_point=training_environment.user_entry_point,
                        args=training_environment.to_cmd_args(),
                        env_vars=training_environment.to_env_vars(),
                        capture_error=capture_error,
                        runner_type=runner_type)
    except errors.ExecuteUserScriptError as err:
        message = str(err)
        # Tolerate a known Gloo error surfaced by the user script; re-raise anything else.
        if "terminate called after throwing an instance of 'gloo::EnforceNotMet'" in message:
            logger.warning('Known exception: {}'.format(message))
        else:
            info = sys.exc_info()
            six.reraise(info[0], err, info[2])
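
For context, here is a minimal sketch of how this entry point could be invoked from the container's main routine, assuming the sagemaker_training toolkit's environment.Environment class; the main() wrapper below is illustrative rather than the toolkit's exact code.

from sagemaker_training import environment


def main():
    # Build the training environment from the SageMaker-provided resource config,
    # hyperparameters, and environment variables, then hand it to train().
    train(environment.Environment())


if __name__ == '__main__':
    main()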