in src/sagemaker_xgboost_container/dmlc_patch/tracker.py [0:0]
def accept_slaves(self, nslave):
    """Run the tracker accept loop until every worker has reported shutdown.

    Connections are accepted on ``self.sock`` and wrapped in ``SlaveEntry``.
    Each connection carries one command (``s.cmd``):

    * ``"print"``    -- a string from the peer that is written to the log.
    * ``"shutdown"`` -- the worker with rank ``s.rank`` has finished its job.
    * ``"start"`` / ``"recover"`` -- the worker needs a rank plus the link
      maps (``tree_map``, ``parent_map``, ``ring_map``) returned by
      ``self.get_link_map``.

    Workers for which ``SlaveEntry.decide_rank`` returns -1 are held in a
    pending list. Once every still-unassigned rank has a pending worker, the
    whole batch is sorted by ``host`` and ranks are handed out in that order.

    Args:
        nslave: Expected number of workers. Replaced by the ``world_size`` of
            the first ``"start"`` message when that value is positive.

    Side effects:
        Sets ``self.start_time`` when the batch path hands out the last rank,
        and ``self.end_time`` when the loop finishes.

    Raises:
        AssertionError: If a worker violates the expected protocol (duplicate
            shutdown, ``"recover"`` before any ``"start"``, mismatched
            ``world_size``, ...).
    """
    # rank -> SlaveEntry for nodes that have finished the job
    shutdown = {}
    # rank -> SlaveEntry for nodes that are waiting for connections
    wait_conn = {}
    # maps job id to rank
    job_map = {}
    # workers waiting to be assigned a rank in the next batch
    pending = []
    # link maps are built lazily from the first "start" message
    tree_map = None
    logger.debug("Looking for %s connections.", nslave)
    while len(shutdown) != nslave:
        fd, s_addr = self.sock.accept()
        logger.debug("Accepted connection")
        logger.debug(fd)
        logger.debug(s_addr)
        try:
            s = SlaveEntry(fd, s_addr)
        except socket.timeout:
            logger.info("No data received from connection %s. Closing.", s_addr)
            # Actually close the accepted socket so a silent peer does not
            # leak a file descriptor for the lifetime of the tracker.
            fd.close()
            continue
        logger.debug("Slave command is: %s", s.cmd)
        if s.cmd == "print":
            msg = s.sock.recvstr()
            logger.debug("PRINTING FROM %s:%s", fd, s_addr)
            logger.info(msg.strip())
            continue
        if s.cmd == "shutdown":
            assert s.rank >= 0 and s.rank not in shutdown
            assert s.rank not in wait_conn
            shutdown[s.rank] = s
            logger.debug("Recieve %s signal from %d", s.cmd, s.rank)
            continue
        assert s.cmd == "start" or s.cmd == "recover"
        # lazily initialize the link maps from the first worker to start
        if tree_map is None:
            assert s.cmd == "start"
            if s.world_size > 0:
                nslave = s.world_size
            tree_map, parent_map, ring_map = self.get_link_map(nslave)
            # ranks that have not been handed out yet
            todo_nodes = list(range(nslave))
        else:
            assert s.world_size == -1 or s.world_size == nslave
        if s.cmd == "recover":
            assert s.rank >= 0
        rank = s.decide_rank(job_map)
        if rank == -1:
            # Batch assignment: ranks are only handed out once every
            # still-unassigned rank has a pending worker.
            assert len(todo_nodes) != 0
            pending.append(s)
            logger.debug("Pending slaves: %s", pending)
            logger.debug("TO-do slaves: %s", todo_nodes)
            if len(pending) == len(todo_nodes):
                # order the batch by host before handing out ranks
                pending.sort(key=lambda x: x.host)
                for entry in pending:
                    rank = todo_nodes.pop(0)
                    if entry.jobid != "NULL":
                        job_map[entry.jobid] = rank
                    entry.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
                    if entry.wait_accept > 0:
                        wait_conn[rank] = entry
                    logger.info("Recieve %s signal from %s; assign rank %d", entry.cmd, entry.host, entry.rank)
            if len(todo_nodes) == 0:
                logger.info("@tracker All of %d nodes getting started", nslave)
                self.start_time = time.time()
        else:
            s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
            logger.debug("Recieve %s signal from %d", s.cmd, s.rank)
            if s.wait_accept > 0:
                wait_conn[rank] = s
    logger.info("@tracker All nodes finishes job")
    self.end_time = time.time()
    # start_time is only assigned above when the batch path hands out the
    # last rank; if every rank came from decide_rank it may be missing (or
    # left at whatever the constructor set -- NOTE(review): presumably None).
    start_time = getattr(self, "start_time", None)
    if start_time is None:
        logger.info("@tracker node start time was not recorded; job duration unavailable")
    else:
        logger.info("@tracker %s secs between node start and job finish", str(self.end_time - start_time))