def rabit_run()

in src/sagemaker_xgboost_container/distributed.py


def rabit_run(exec_fun, args, include_in_training, hosts, current_host,
              first_port=None, second_port=None, max_connect_attempts=None,
              connect_retry_timeout=3, update_rabit_args=False):
    """Run execution function after initializing dmlc/rabit.

    This method initializes rabit twice:
        1. To broadcast to all hosts which hosts should be included in training.
        2. To run distributed xgb.train() with only the hosts from step 1.

    :param exec_fun: Function to run while rabit is initialized. xgb.train() must run in the same process space
                    in order to utilize rabit initialization. Note that the execution function must also take the args
                    'is_distributed' and 'is_master'.
    :param args: Arguments to run execution function.
    :param include_in_training: Boolean indicating whether the current host should be used in training. This is
                                broadcast so that all hosts in the cluster know which hosts to include during training.
    :param hosts: List of all hosts in the cluster.
    :param current_host: Name of the current host.
    :param first_port: Port to use for the initial rabit initialization. If None, rabit defaults this to 9099.
    :param second_port: Port to use for the second rabit initialization. If None, the port used by the first
                        initialization is incremented by 1.
    :param max_connect_attempts: Maximum number of attempts to connect to the rabit tracker; passed through to Rabit.
    :param connect_retry_timeout: Timeout between connection retries; passed through to Rabit. Defaults to 3.
    :param update_rabit_args: Boolean to include rabit information in args. If True, the following is added:
                                is_master
    """
    with Rabit(
            hosts=hosts,
            current_host=current_host,
            port=first_port,
            max_connect_attempts=max_connect_attempts,
            connect_retry_timeout=connect_retry_timeout) as rabit:
        hosts_with_data = rabit.synchronize({'host': rabit.current_host, 'include_in_training': include_in_training})
        hosts_with_data = [record['host'] for record in hosts_with_data if record['include_in_training']]

        # Keep track of the port used, so that hosts trying to shut down know when the server is not available
        previous_port = rabit.master_port

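    # Hosts that are excluded from training exit here and never join the second rabit ring.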
    if not include_in_training:
        logging.warning("Host {} not being used for distributed training.".format(current_host))
        sys.exit(0)

    second_rabit_port = second_port if second_port else previous_port + 1

    if len(hosts_with_data) > 1:
        # Set up rabit with the hosts that have data, on an unused port, so that hosts from the
        # previous rabit configuration don't interfere with the new one
        with Rabit(
                hosts=hosts_with_data,
                current_host=current_host,
                port=second_rabit_port,
                max_connect_attempts=max_connect_attempts,
                connect_retry_timeout=connect_retry_timeout) as cluster:
            if update_rabit_args:
                args.update({'is_master': cluster.is_master})
            exec_fun(**args)

    elif len(hosts_with_data) == 1:
        logging.debug("Only 1 host with training data, "
                      "starting single node training job from: {}".format(current_host))
        if update_rabit_args:
            args.update({'is_master': True})
        exec_fun(**args)

    else:
        raise RuntimeError("No hosts received training data.")
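
A minimal usage sketch, assuming a hypothetical training function train_job, a hypothetical argument name
train_cfg, and example host names; the real caller supplies its own training arguments. Note that rabit_run
injects 'is_master' into args only when update_rabit_args=True, while 'is_distributed' must be provided by
the caller.

import logging

from sagemaker_xgboost_container.distributed import rabit_run


def train_job(train_cfg, is_distributed, is_master):
    # xgb.train() would run here, in the same process that initialized rabit.
    if is_master:
        logging.info("Running as the rabit master host.")


train_args = {"train_cfg": {"max_depth": 5}, "is_distributed": True}

rabit_run(
    exec_fun=train_job,
    args=train_args,
    include_in_training=True,    # e.g. this host received training data
    hosts=["algo-1", "algo-2"],
    current_host="algo-1",
    update_rabit_args=True,      # injects 'is_master' into train_args
)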