def train(env)

in src/sagemaker_mxnet_container/training.py [0:0]


def train(env):
    """Run MXNet training for the user's entry point.

    When parameter server training was requested, the MXNet scheduler and
    server processes are started first and the worker role is configured
    through environment variables; the user's script is then executed with
    either an MPI runner or a plain process runner.
    """
    logger.info('MXNet training environment: {}'.format(env.to_env_vars()))

    if env.additional_framework_parameters.get(LAUNCH_PS_ENV_NAME, False):
        # Parameter server mode: confirm every host in the cluster is
        # resolvable before launching any processes.
        _verify_hosts(env.hosts)

        ps_port = env.hyperparameters.get('_ps_port', '8000')
        ps_verbose = env.hyperparameters.get('_ps_verbose', '0')

        logger.info('Starting distributed training task')

        # Only the scheduler host starts the scheduler process; every host
        # starts a server process and then acts as a worker itself.
        if scheduler_host(env.hosts) == env.current_host:
            _run_mxnet_process('scheduler', env.hosts, ps_port, ps_verbose)
        _run_mxnet_process('server', env.hosts, ps_port, ps_verbose)
        os.environ.update(_env_vars_for_role('worker', env.hosts, ps_port, ps_verbose))

    mpi_enabled = env.additional_framework_parameters.get(LAUNCH_MPI_ENV_NAME)

    # Run the user script under MPI when MPI training was requested,
    # otherwise as a regular subprocess.
    if mpi_enabled:
        runner_type = runner.MPIRunnerType
    else:
        runner_type = runner.ProcessRunnerType

    entry_point.run(uri=env.module_dir,
                    user_entry_point=env.user_entry_point,
                    args=env.to_cmd_args(),
                    env_vars=env.to_env_vars(),
                    runner_type=runner_type)