in launcher/nemo/stages.py [0:0]
def _make_train_script_text(self, stage_cfg_path=None, port=41000) -> str:
    """
    The custom train entry script; it is responsible for the following:
    - Resolving the hostname and creating the torch distributed args
    - Pulling from GitHub if required
    - Launching the torchrun command
    """
    nodes = get_num_nodes(self.stage_cfg)
    ntasks_per_node = get_ntasks_per_node(self.stage_cfg)
    script_text = ["#!/bin/bash", "set -ex"]

    # Also export env vars here so that they can be consumed by the docker container
    env_vars = self.get_env_vars()
    if env_vars:
        script_text.extend([f"export {k}={v}" for k, v in env_vars.items()])
    # Prepare the host information needed to build the torchrun command
    if nodes > 1:
        script_text.extend(
            [
                f"{MASTER_ADDR}=$(head -n 1 {str(self._get_hostfile_location())})",
                f'{NODEID}=$(($(grep -nx -o "\\b$(hostname)\\b" {str(self._get_hostfile_location())} | cut -d ":" -f 1) - 1))',
                f"{NNODES}={nodes}",
                f"{PROCESSES_PER_NODE}={ntasks_per_node}",
                f"{MASTER_PORT}={port}",
                "",
            ]
        )
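        # NODEID is this host's zero-based rank: `grep -nx` finds the line number
        # of $(hostname) in the hostfile and the subtraction shifts it to 0-based.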
if self.device == "trainium":
script_text.append(
f'{DISTRIBUTED_ARGS}="--nproc_per_node ${PROCESSES_PER_NODE} --nnodes ${NNODES} --node_rank ${NODEID} --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT}"'
)
else:
script_text.append(
f'{DISTRIBUTED_ARGS}="--nproc_per_node ${PROCESSES_PER_NODE} --nnodes ${NNODES} --rdzv_endpoint=${MASTER_ADDR} --rdzv_id=100 --rdzv_backend=c10d"'
)
else:
script_text.append(f'{DISTRIBUTED_ARGS}="--nproc_per_node {ntasks_per_node}"')
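    # On trainium the rank/master pair is passed explicitly, while multi-node GPU
    # runs rely on a c10d rendezvous; a single-node run only needs the local
    # process count.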
    # Prepare the GitHub pull.
    # Aligns with the train-script preparation in launcher/nemo/k8s_templates/training.yaml
    script_text.append("")
    if self.cfg.get("git", None) is not None or self._default_repo is not None:
        repo_url_or_path = self._default_repo
        branch = self._default_branch
        if self.cfg.get("git", None) is not None:
            if self.cfg.git.get("repo_url_or_path", None) is not None:
                repo_url_or_path = str(self.cfg.git.get("repo_url_or_path"))
            assert repo_url_or_path is not None, "`repo_url_or_path` must be defined when setting git config"
            if self.cfg.git.get("token", None) is not None:
                repo_url_or_path = self.insert_git_token(repo_url_or_path, self.cfg.git.token)
            if self.cfg.git.get("branch", None) is not None:
                branch = self.cfg.git.branch
        if not self._use_local_repo():
            # Remote repo: clone from the repo URL
            script_text.extend(
                [
                    "# For greater env stability, grab hostname from `hostname`",
                    "# https://sim.amazon.com/issues/P162624109",
                    'LAUNCHER_HOSTNAME="$(hostname)"',
                    "",
                    "mkdir -p $HOME/tmp",
                    'GIT_CLONE_DIR="$HOME/tmp/$LAUNCHER_HOSTNAME"',
                    "[[ -d $GIT_CLONE_DIR ]] && rm -rf $GIT_CLONE_DIR",
                    f"git clone {repo_url_or_path} $GIT_CLONE_DIR",
                    "GIT_CLONE_DIR=${GIT_CLONE_DIR}/",
                    "cd $GIT_CLONE_DIR",
                    # The cache can lead to unexpected behavior when a user clones
                    # the Adapter and modifies it
                    "rm -rf __pycache__",
                ]
            )
        else:
            # Simply cd into the directory for a local repo
            script_text.append(f"cd {repo_url_or_path}")
        if branch is not None:
            script_text.append(f"git checkout {branch}")
        if self.cfg.get("git", None) is not None and self.cfg.git.get("commit", None) is not None:
            script_text.append(f"git fetch origin {self.cfg.git.commit}")
            script_text.append(f"git reset --hard {self.cfg.git.commit}")
        if OmegaConf.select(self.cfg, "git.update_adapter", default=False):
            script_text.append("\npip install . --force-reinstall --no-deps")
    else:
        script_text.append('GIT_CLONE_DIR=""')
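    # With no git config and no default repo, GIT_CLONE_DIR is set to an empty
    # string, presumably so downstream commands can reference it unconditionally.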
    if not OmegaConf.select(self.cfg, "training.run.model_type", default="").startswith("neuron"):
        script_text.append("")
        script_text.append("unset SLURM_NTASKS")
    if get_container_type(self.cfg.get("container", None)) == "enroot" and self.cluster == "bcm":
        if OmegaConf.select(self.cfg, "recipes.model.multi_modal", default=False):
            transformers_upgrade_cmd = "pip install transformers==4.45.2"
            script_text.append("")
            script_text.append(transformers_upgrade_cmd)
        if OmegaConf.select(self.cfg, "recipes.model.model_type", default=None) == "deepseek_r1":
            transformers_upgrade_cmd = "pip install transformers==4.48.2"
            script_text.append("")
            script_text.append(transformers_upgrade_cmd)
        if OmegaConf.select(self.cfg, "recipes.model.model_type", default=None) == "llama_v4":
            transformers_upgrade_cmd = "pip install transformers==4.51.3"
            script_text.append("")
            script_text.append(transformers_upgrade_cmd)
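    # The pins above swap in specific transformers versions for enroot containers
    # on BCM clusters: 4.45.2 for multi-modal recipes, 4.48.2 for deepseek_r1,
    # and 4.51.3 for llama_v4.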
script_text.append("")
script_text.append(self._make_custom_call_string(stage_cfg_path))
return "\n".join(script_text)