# train() — excerpt from src/sagemaker_tensorflow_container/training.py

def train(env, cmd_args):
    """Run the training job described by ``env``.

    Selects the distribution setup (parameter server, TF
    MultiWorkerMirroredStrategy, or SMDataParallel) from the additional
    framework parameters in ``env``, then launches the user entry point.

    Args:
        env (sagemaker_training.environment.Environment): Instance of Environment class
        cmd_args (list): command-line arguments forwarded to the user entry point.
    """
    parameter_server_enabled = (
        env.additional_framework_parameters.get(SAGEMAKER_PARAMETER_SERVER_ENABLED, False)
        and len(env.hosts) > 1
    )
    multi_worker_mirrored_strategy_enabled = env.additional_framework_parameters.get(
        SAGEMAKER_MULTI_WORKER_MIRRORED_STRATEGY_ENABLED, False
    )
    sagemaker_distributed_dataparallel_enabled = env.additional_framework_parameters.get(
        SAGEMAKER_DISTRIBUTED_DATAPARALLEL_ENABLED, False
    )

    env_vars = env.to_env_vars()

    # Setup: only instances that belong to a distribution instance group
    # build the TF cluster configuration.
    tf_config = None  # remains None unless the parameter-server setup runs below
    if env.current_instance_group in env.distribution_instance_groups:
        if parameter_server_enabled:
            tf_config = _build_tf_config_for_ps(
                hosts=env.distribution_hosts, current_host=env.current_host
            )
            logger.info("Running distributed training job with parameter servers")

        elif multi_worker_mirrored_strategy_enabled:
            # MWMS reads its cluster spec from the TF_CONFIG environment variable.
            env_vars["TF_CONFIG"] = json.dumps(
                _build_tf_config_for_mwms(hosts=env.distribution_hosts, current_host=env.current_host)
            )
            logger.info("Running distributed training job with multi_worker_mirrored_strategy setup")

    runner_type = runner.ProcessRunnerType

    # Run
    # Guard on tf_config as well: previously an instance outside the
    # distribution instance groups with parameter servers enabled would hit
    # a NameError on the unbound tf_config; such instances now fall through
    # to the plain process runner instead.
    if parameter_server_enabled and tf_config is not None:

        logger.info("Launching parameter server process")
        _run_ps(env, tf_config["cluster"])
        logger.info("Launching worker process")
        _run_worker(env, cmd_args, tf_config)

        # Non-master hosts block until the master goes down so the job does
        # not tear down while training is still in progress.
        if not _is_host_master(env.hosts, env.current_host):
            _wait_until_master_is_down(env.hosts[0])

    else:
        if env.current_instance_group in env.distribution_instance_groups:
            mpi_enabled = env.additional_framework_parameters.get("sagemaker_mpi_enabled")

            if mpi_enabled:
                runner_type = runner.MPIRunnerType
            elif sagemaker_distributed_dataparallel_enabled:
                runner_type = runner.SMDataParallelRunnerType

        entry_point.run(
            uri=env.module_dir,
            user_entry_point=env.user_entry_point,
            args=cmd_args,
            env_vars=env_vars,
            capture_error=True,
            runner_type=runner_type,
        )