in src/sagemaker_training/smdataparallel.py [0:0]
def _create_command(self):
    """Create the mpi-based smddprun command.

    The command differs depending on the number of hosts:

    * Single-node (one host): ``SMDATAPARALLEL_USE_SINGLENODE=1`` is used and
      ``self._hosts`` is passed through unchanged as the host list.
    * Multi-node (more than one host): ``SMDATAPARALLEL_USE_HOMOGENEOUS=1`` is
      used, each host is listed as ``"host:<processes_per_host>"``, and the
      master hostname plus server port 7592 are passed along.

    In both modes the total process count is
    ``self._processes_per_host * len(self._hosts)``.

    Returns:
        The command produced by ``self._get_mpirun_command`` for the chosen mode.
    """
    num_hosts = len(self._hosts)
    num_processes = self._processes_per_host * num_hosts

    # Hand the values to the logger lazily instead of pre-formatting with ``%``:
    # eager ``"... %s" % self._hosts`` raises/mis-formats when ``_hosts`` is a
    # tuple, and formats even when INFO logging is disabled.
    logger.info("Network interface name: %s", self._network_interface_name)
    logger.info("Host: %s", self._hosts)

    if num_hosts > 1:
        # Multi-node: homogeneous mode, one "host:slots" entry per host.
        # NOTE(review): a previous comment stated 16 processes per host
        # (8 server; 8 worker), but the slot count used here is
        # self._processes_per_host — confirm which is intended.
        smdataparallel_server_addr = self._master_hostname
        smdataparallel_server_port = 7592
        host_list = [
            "{}:{}".format(host, self._processes_per_host) for host in self._hosts
        ]
        command = self._get_mpirun_command(
            num_hosts,
            host_list,
            "SMDATAPARALLEL_USE_HOMOGENEOUS=1",
            num_processes,
            smdataparallel_server_addr,
            smdataparallel_server_port,
        )
    else:
        # Single-node: hosts are passed through without slot counts, and the
        # server address/port arguments are not supplied.
        host_list = self._hosts
        command = self._get_mpirun_command(
            num_hosts, host_list, "SMDATAPARALLEL_USE_SINGLENODE=1", num_processes
        )

    logger.info(
        "Env Hosts: %s Hosts: %s process_per_hosts: %s num_processes: %s",
        self._hosts,
        host_list,
        self._processes_per_host,
        num_processes,
    )
    return command