in src/sagemaker_xgboost_container/distributed.py [0:0]
def __init__(self,
             hosts,
             current_host=None,
             master_host=None,
             port=None,
             max_connect_attempts=None,
             connect_retry_timeout=3):
    """Record the cluster layout and connection settings for rabit initialization.

    :param hosts: List of hostnames. Stored sorted so every node derives the same ordering.
    :param current_host: Current hostname. If not provided (or empty), LOCAL_HOSTNAME is used.
    :param master_host: Master host hostname. If not provided, use alphabetically first hostname amongst hosts
                        to ensure determinism in choosing master node.
    :param port: Port to connect to master, if not specified use 9099.
    :param max_connect_attempts: Number of times to try connecting to RabitTracker. If this arg is set
                                 to None, try indefinitely.
    :param connect_retry_timeout: Timeout value when attempting to connect to RabitTracker.
                                  It is only stored here; the retry logic lives outside this method.
    :raises ValueError: If ``hosts`` is empty while ``master_host`` is not provided, or if
                        ``max_connect_attempts`` is neither None nor greater than 0.
    """
    # NOTE(review): earlier docs stated connect_retry_timeout is ignored when
    # max_connect_attempts is None -- confirm against the code that performs the connection.

    # Get the host information. This is used to identify the master host
    # that will run the RabitTracker and also to work out how many clients/slaves
    # exist (this will ensure that all-reduce is set up correctly and that
    # it blocks whilst waiting for those hosts to process the data).
    self.current_host = current_host if current_host else LOCAL_HOSTNAME
    self.logger = self._get_logger(self.current_host)
    self.logger.debug("Found current host.")

    self.hosts = sorted(hosts)
    self.n_workers = len(self.hosts)
    self.logger.debug("Found hosts: {} [{}]".format(self.hosts, self.n_workers))

    # We use the first lexicographically named host as the master if not indicated otherwise
    if not master_host:
        if not self.hosts:
            # Without this guard the lookup below fails with a bare IndexError
            # that gives no hint about which argument was wrong.
            raise ValueError("hosts must contain at least one hostname when master_host is not provided.")
        master_host = self.hosts[0]
    self.master_host = master_host
    self.is_master_host = self.current_host == self.master_host

    self.logger.debug("Is Master: {}".format(self.is_master_host))
    self.logger.debug("Master: {}".format(self.master_host))

    # We start the RabitTracker on a known port on the first host. We can
    # do this since SageMaker Training instances are single tenant and we
    # don't need to worry about port contention.
    if port is None:
        port = 9099
        self.logger.debug("No port specified using: {}".format(port))
    else:
        self.logger.debug("Using provided port: {}".format(port))
    self.port = port

    # None means "retry forever"; any other value must be strictly positive.
    if max_connect_attempts is not None and not max_connect_attempts > 0:
        raise ValueError("max_connect_attempts must be None or an integer greater than 0.")
    self.max_connect_attempts = max_connect_attempts
    self.connect_retry_timeout = connect_retry_timeout