def _create_command()

in src/sagemaker_training/torch_distributed.py [0:0]


    def _create_command(self):
        """
        The torchrun command differs based on the number of hosts.
        The elasticity feature of torchrun is currently not supported.
        """
        self._setup()
        entrypoint_type = _entry_point_type.get(environment.code_dir, self._user_entry_point)

        if entrypoint_type is _entry_point_type.PYTHON_PACKAGE:
            raise errors.ClientError(
                "Python packages are not supported for torch_distributed. "
                "Please use a python script as the entry-point"
            )

        if entrypoint_type is _entry_point_type.PYTHON_PROGRAM:
            num_hosts = len(self._hosts)
            torchrun_cmd = []

            # Prepend neuron_parallel_compile to precompile XLA graphs when the
            # environment variable RUN_NEURON_PARALLEL_COMPILE == "1".
            # This is an example of the command line output when this flag is set:
            # "neuron_parallel_compile torchrun --nnodes 2 --nproc_per_node 32
            # --master_addr algo-1 --master_port 7777 --node_rank 0 trn_train.py
            # --max_steps 100"

            if os.environ.get("RUN_NEURON_PARALLEL_COMPILE") == "1":
                torchrun_cmd.append("neuron_parallel_compile")

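            # Base torchrun options: the total number of nodes and the number of
            # worker processes to launch on each node.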
            node_options = [
                TORCH_DISTRIBUTED_MODULE,
                "--nnodes",
                str(num_hosts),
                "--nproc_per_node",
                str(self._processes_per_host),
            ]

            torchrun_cmd += node_options

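            # Rendezvous options used only for multi-node jobs: the master host's
            # address and port, plus this host's rank in the host list.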
            multinode_options = [
                "--master_addr",
                str(self._master_hostname),
                "--master_port",
                MASTER_PORT,
                "--node_rank",
                str(self._hosts.index(self._current_host)),
            ]

            if num_hosts > 1:
                torchrun_cmd += multinode_options

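            # Finally, append the user's entry-point script and its command-line arguments.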
            torchrun_cmd.append(str(self._user_entry_point))
            torchrun_cmd += self._args
            return torchrun_cmd
        else:
            raise errors.ClientError("Unsupported entry point type for torch_distributed")
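
For reference, a minimal sketch of the command list this method returns for a two-host job with 32 processes per host, matching the example in the comment above (the hostnames, port, script name, and script arguments are illustrative values taken from that comment, not constants defined by this module):

    # Illustrative only; real values come from the training job configuration.
    # With RUN_NEURON_PARALLEL_COMPILE == "1", "neuron_parallel_compile" would be
    # prepended as the first element.
    [
        "torchrun",
        "--nnodes", "2",
        "--nproc_per_node", "32",
        "--master_addr", "algo-1",
        "--master_port", "7777",
        "--node_rank", "0",
        "trn_train.py",
        "--max_steps", "100",
    ]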