def prepare_simple_launcher_cmd_env()

in src/accelerate/utils/launch.py [0:0]


def prepare_simple_launcher_cmd_env(args: argparse.Namespace) -> tuple[list[str], dict[str, str]]:
    """
    Prepares and returns the command list and an environment dict populated with the simple launcher's environment variables.
    """
    cmd = []
    if args.no_python and args.module:
        raise ValueError("--module and --no_python cannot be used together")

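    # If an MPI hostfile is supplied, prefix the command with the MPI launcher and its
    # hostfile, per-node process count, total process count, and binding arguments.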
    num_processes = getattr(args, "num_processes", None)
    num_machines = args.num_machines
    if args.mpirun_hostfile is not None:
        mpi_app_name, hostfile_arg, num_proc_arg, proc_per_node_arg, bind_to_arg = _get_mpirun_args()
        bind_to = getattr(args, "bind-to", "socket")
        nproc_per_node = str(num_processes // num_machines) if num_processes and num_machines else "1"
        cmd += [
            mpi_app_name,
            hostfile_arg,
            args.mpirun_hostfile,
            proc_per_node_arg,
            nproc_per_node,
        ]
        if num_processes:
            cmd += [num_proc_arg, str(num_processes)]
        if bind_to_arg:
            cmd += [bind_to_arg, bind_to]
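    # Unless --no_python is set, run the training script through the current Python
    # interpreter, adding -m when launching it as a module.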
    if not args.no_python:
        cmd.append(sys.executable)
        if args.module:
            cmd.append("-m")
    cmd.append(args.training_script)
    cmd.extend(args.training_script_args)

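    # Start from a copy of the current environment and layer the Accelerate launcher
    # variables on top.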
    current_env = os.environ.copy()
    current_env["ACCELERATE_USE_CPU"] = str(args.cpu or args.use_cpu)
    if args.debug:
        current_env["ACCELERATE_DEBUG_MODE"] = "true"
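    # Restrict visible accelerators to the requested ids using the environment variable
    # of the detected backend (XPU, MLU, SDAA, MUSA, NPU, HPU, or CUDA).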
    if args.gpu_ids != "all" and args.gpu_ids is not None:
        if is_xpu_available():
            current_env["ZE_AFFINITY_MASK"] = args.gpu_ids
        elif is_mlu_available():
            current_env["MLU_VISIBLE_DEVICES"] = args.gpu_ids
        elif is_sdaa_available():
            current_env["SDAA_VISIBLE_DEVICES"] = args.gpu_ids
        elif is_musa_available():
            current_env["MUSA_VISIBLE_DEVICES"] = args.gpu_ids
        elif is_npu_available():
            current_env["ASCEND_RT_VISIBLE_DEVICES"] = args.gpu_ids
        elif is_hpu_available():
            current_env["HABANA_VISIBLE_MODULES"] = args.gpu_ids
        else:
            current_env["CUDA_VISIBLE_DEVICES"] = args.gpu_ids
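    # Multi-node launches require an explicit main process IP and port.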
    if num_machines > 1:
        assert args.main_process_ip is not None, (
            "When using multiple machines, you need to specify the main process IP."
        )
        assert args.main_process_port is not None, (
            "When using multiple machines, you need to specify the main process port."
        )

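    # Rendezvous settings for distributed runs: MASTER_ADDR/MASTER_PORT default to
    # 127.0.0.1:29500 and the CCL worker count is exported; KMP thread-tuning variables
    # are set as well.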
    ccl_worker_count = getattr(args, "mpirun_ccl", 0) if is_ccl_available() else 0
    if (num_processes is not None and num_processes > 1) or num_machines > 1:
        current_env["MASTER_ADDR"] = args.main_process_ip if args.main_process_ip is not None else "127.0.0.1"
        current_env["MASTER_PORT"] = str(args.main_process_port) if args.main_process_port is not None else "29500"
        current_env["CCL_WORKER_COUNT"] = str(ccl_worker_count)
    if current_env["ACCELERATE_USE_CPU"]:
        current_env["KMP_AFFINITY"] = "granularity=fine,compact,1,0"
        current_env["KMP_BLOCKTIME"] = str(1)

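    # Validate the requested mixed-precision mode; fp8 additionally requires
    # Transformer Engine, MSAMP, or torchao to be installed.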
    try:
        mixed_precision = PrecisionType(args.mixed_precision.lower())
    except ValueError:
        raise ValueError(
            f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}."
        )

    current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision)
    if args.mixed_precision.lower() == "fp8":
        if not is_fp8_available():
            raise RuntimeError(
                "FP8 is not available on this machine. Please ensure that either Transformer Engine, MSAMP or torchao is installed."
            )
        current_env = setup_fp8_env(args, current_env)

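    # Validate the TorchDynamo backend and export the dynamo compilation options.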
    try:
        dynamo_backend = DynamoBackend(args.dynamo_backend.upper())
    except ValueError:
        raise ValueError(
            f"Unknown dynamo backend: {args.dynamo_backend.upper()}. Choose between {DynamoBackend.list()}."
        )
    current_env["ACCELERATE_DYNAMO_BACKEND"] = dynamo_backend.value
    current_env["ACCELERATE_DYNAMO_MODE"] = args.dynamo_mode
    current_env["ACCELERATE_DYNAMO_USE_FULLGRAPH"] = str(args.dynamo_use_fullgraph)
    current_env["ACCELERATE_DYNAMO_USE_DYNAMIC"] = str(args.dynamo_use_dynamic)
    current_env["ACCELERATE_DYNAMO_USE_REGIONAL_COMPILATION"] = str(args.dynamo_use_regional_compilation)

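    # Per-process CPU thread count, optional IPEX usage, and CPU affinity control.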
    current_env["OMP_NUM_THREADS"] = str(args.num_cpu_threads_per_process)
    if is_ipex_available():
        current_env["ACCELERATE_USE_IPEX"] = str(args.ipex).lower()
    if args.enable_cpu_affinity:
        current_env["ACCELERATE_CPU_AFFINITY"] = "1"
    return cmd, current_env
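
Usage sketch (illustrative, not part of launch.py): the argparse.Namespace below is hand-built and fills in only the attributes this function reads, with assumed values for a single-node, two-process run; in practice these arguments come from the accelerate launch CLI parser.

import argparse
import subprocess

from accelerate.utils.launch import prepare_simple_launcher_cmd_env

# Hypothetical namespace: only the attributes read by the function above are set.
args = argparse.Namespace(
    no_python=False,
    module=False,
    num_processes=2,
    num_machines=1,
    mpirun_hostfile=None,
    training_script="train.py",          # placeholder script name
    training_script_args=["--epochs", "3"],
    cpu=False,
    use_cpu=False,
    debug=False,
    gpu_ids="all",
    main_process_ip=None,                # defaults to 127.0.0.1 in the env
    main_process_port=None,              # defaults to 29500 in the env
    mixed_precision="no",
    dynamo_backend="no",
    dynamo_mode="default",
    dynamo_use_fullgraph=False,
    dynamo_use_dynamic=False,
    dynamo_use_regional_compilation=False,
    num_cpu_threads_per_process=1,
    ipex=False,
    enable_cpu_affinity=False,
)

cmd, env = prepare_simple_launcher_cmd_env(args)
# cmd is roughly [<python executable>, "train.py", "--epochs", "3"]; env contains
# MASTER_ADDR/MASTER_PORT, ACCELERATE_MIXED_PRECISION, OMP_NUM_THREADS, etc.
subprocess.run(cmd, env=env, check=True)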