def _create_command()

in src/sagemaker_training/smdataparallel.py [0:0]


    def _create_command(self):
        """Create mpi-based smddprun command.

        The command differs based on the number of hosts:

        * Single-node: ``SMDATAPARALLEL_USE_SINGLENODE=1`` is set and
          ``self._hosts`` is passed through unchanged as the host list.
        * Multi-node: ``SMDATAPARALLEL_USE_HOMOGENEOUS=1`` is set. Each host
          is written as ``"<host>:<processes_per_host>"``. The master hostname
          and a fixed port are passed as the smdataparallel server address.

        In both cases the total process count is
        ``self._processes_per_host * len(self._hosts)``.

        Returns:
            The value produced by ``self._get_mpirun_command`` for the
            arguments described above.
        """
        host_list = self._hosts
        num_hosts = len(self._hosts)
        num_processes = self._processes_per_host * num_hosts

        # Hand values to the logger as arguments instead of pre-formatting
        # with ``%``. Pre-formatting raises TypeError when ``self._hosts`` is a
        # tuple (unless it has exactly one element). It also does the string
        # work even when INFO logging is disabled. This matches the lazy
        # style used for the summary log at the end of this method.
        logger.info("Network interface name: %s", self._network_interface_name)
        logger.info("Host: %s", self._hosts)

        if num_hosts > 1:
            # Multi-node: homogeneous mode. Every host is assigned the same
            # number of processes (self._processes_per_host) via "host:N".
            smdataparallel_server_addr = self._master_hostname
            # Fixed port for the smdataparallel server on the master host.
            smdataparallel_server_port = 7592
            host_list = ["{}:{}".format(host, self._processes_per_host) for host in self._hosts]
            smdataparallel_flag = "SMDATAPARALLEL_USE_HOMOGENEOUS=1"
            command = self._get_mpirun_command(
                num_hosts,
                host_list,
                smdataparallel_flag,
                num_processes,
                smdataparallel_server_addr,
                smdataparallel_server_port,
            )
        else:
            # Single-node: no server address/port is passed.
            smdataparallel_flag = "SMDATAPARALLEL_USE_SINGLENODE=1"
            command = self._get_mpirun_command(
                num_hosts, host_list, smdataparallel_flag, num_processes
            )

        msg = "Env Hosts: %s Hosts: %s process_per_hosts: %s num_processes: %s"
        logger.info(msg, self._hosts, host_list, self._processes_per_host, num_processes)

        return command