in src/sagemaker_tensorflow_container/training.py [0:0]
def train(env, cmd_args):
    """Get training job environment from env and run the training job.

    Args:
        env (sagemaker_training.environment.Environment): Instance of Environment class
            holding information about the training job.
        cmd_args (list): Command line arguments passed to the user entry point.
    """
    parameter_server_enabled = (
        env.additional_framework_parameters.get(SAGEMAKER_PARAMETER_SERVER_ENABLED, False)
        and len(env.hosts) > 1
    )
    multi_worker_mirrored_strategy_enabled = env.additional_framework_parameters.get(
        SAGEMAKER_MULTI_WORKER_MIRRORED_STRATEGY_ENABLED, False
    )
    sagemaker_distributed_dataparallel_enabled = env.additional_framework_parameters.get(
        SAGEMAKER_DISTRIBUTED_DATAPARALLEL_ENABLED, False
    )

    env_vars = env.to_env_vars()
    # Setup
    if env.current_instance_group in env.distribution_instance_groups:
        if parameter_server_enabled:
            tf_config = _build_tf_config_for_ps(
                hosts=env.distribution_hosts, current_host=env.current_host
            )
            logger.info("Running distributed training job with parameter servers")
        elif multi_worker_mirrored_strategy_enabled:
            env_vars["TF_CONFIG"] = json.dumps(
                _build_tf_config_for_mwms(
                    hosts=env.distribution_hosts, current_host=env.current_host
                )
            )
            logger.info(
                "Running distributed training job with multi_worker_mirrored_strategy setup"
            )

    runner_type = runner.ProcessRunnerType
    # Run
    if parameter_server_enabled:
        logger.info("Launching parameter server process")
        _run_ps(env, tf_config["cluster"])
        logger.info("Launching worker process")
        _run_worker(env, cmd_args, tf_config)

        if not _is_host_master(env.hosts, env.current_host):
            _wait_until_master_is_down(env.hosts[0])
    else:
        if env.current_instance_group in env.distribution_instance_groups:
            mpi_enabled = env.additional_framework_parameters.get("sagemaker_mpi_enabled")

            if mpi_enabled:
                runner_type = runner.MPIRunnerType
            elif sagemaker_distributed_dataparallel_enabled:
                runner_type = runner.SMDataParallelRunnerType

        entry_point.run(
            uri=env.module_dir,
            user_entry_point=env.user_entry_point,
            args=cmd_args,
            env_vars=env_vars,
            capture_error=True,
            runner_type=runner_type,
        )
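
A minimal sketch of the TF_CONFIG payload that the multi_worker_mirrored_strategy branch serializes into env_vars["TF_CONFIG"], assuming TensorFlow's standard TF_CONFIG layout (a "cluster" spec plus a "task" entry for the current host). The "algo-N" host names follow SageMaker's naming convention; the port and the exact fields emitted by _build_tf_config_for_mwms are illustrative assumptions, since that helper is not shown above.

import json

hosts = ["algo-1", "algo-2"]          # stands in for env.distribution_hosts
current_host = "algo-1"               # stands in for env.current_host
port = 8890                           # placeholder port, not taken from the source

tf_config = {
    "cluster": {"worker": ["{}:{}".format(host, port) for host in hosts]},
    "environment": "cloud",
    "task": {"index": hosts.index(current_host), "type": "worker"},
}

print(json.dumps(tf_config))
# {"cluster": {"worker": ["algo-1:8890", "algo-2:8890"]}, "environment": "cloud",
#  "task": {"index": 0, "type": "worker"}}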