in src/sagemaker_mxnet_container/training.py [0:0]
def train(env):
    """Run user-provided MXNet training code for the given SageMaker environment.

    When the launch-parameter-server option is enabled in the environment's
    additional framework parameters, a distributed setup is started first:
    the hosts are verified, a scheduler process is launched on the designated
    scheduler host, a server process is launched on every host, and the
    current process is configured (via environment variables) as a worker.
    Finally the user entry point is executed, under MPI when the MPI launch
    option is enabled, otherwise as a plain process.

    Args:
        env: the SageMaker training environment object (provides hosts,
            hyperparameters, framework parameters, module location, etc.).
    """
    logger.info('MXNet training environment: {}'.format(env.to_env_vars()))

    ps_enabled = env.additional_framework_parameters.get(LAUNCH_PS_ENV_NAME, False)
    if ps_enabled:
        _verify_hosts(env.hosts)

        # Hidden hyperparameters controlling the parameter-server processes.
        ps_port = env.hyperparameters.get('_ps_port', '8000')
        ps_verbose = env.hyperparameters.get('_ps_verbose', '0')

        logger.info('Starting distributed training task')

        # Exactly one host runs the scheduler process.
        if scheduler_host(env.hosts) == env.current_host:
            _run_mxnet_process('scheduler', env.hosts, ps_port, ps_verbose)

        # Every host runs a server; the current process itself acts as a worker.
        _run_mxnet_process('server', env.hosts, ps_port, ps_verbose)
        os.environ.update(_env_vars_for_role('worker', env.hosts, ps_port, ps_verbose))

    mpi_enabled = env.additional_framework_parameters.get(LAUNCH_MPI_ENV_NAME)
    runner_type = runner.MPIRunnerType if mpi_enabled else runner.ProcessRunnerType

    entry_point.run(uri=env.module_dir,
                    user_entry_point=env.user_entry_point,
                    args=env.to_cmd_args(),
                    env_vars=env.to_env_vars(),
                    runner_type=runner_type)