ablations/training/launch_exp.py (226 lines of code) (raw):

"""Launch a nanotron training ablation run on Slurm.

Builds a nanotron ``Config`` for the Llama model defined below, saves it as YAML under
``LAUNCH_CONFIGS_PATH/<run_name>/`` (and checks it reloads to the same values), then submits an
sbatch script that:

* optionally copies an ``s3://`` dataset to node-local scratch,
* runs ``NANOTRON_RUN_TRAIN_SCRIPT`` through ``torch.distributed.run`` on every node,
* removes the node-local dataset copy afterwards.

Checkpoints are configured to be uploaded to ``S3_CHECKPOINTS_PREFIX/<RUN>``. The launch is
skipped if ``latest.txt`` in that folder already records ``--train_steps``.

Usage::

    python launch_exp.py DATA RUN_NAME LANGUAGE [-d JOB_ID] [--seed N] [--train_steps N] [--priority QOS]

Set ``DEBUG_MODE`` to anything other than ``0`` to launch a small single-node configuration.
"""

import os
import subprocess
import sys
import tempfile
from datetime import datetime
from pathlib import Path

from nanotron.logging import human_format
from nanotron.models.llama import LlamaConfig
from datatrove.io import get_datafolder
from nanotron.config import DatasetStageArgs, NanosetDatasetsArgs, S3UploadArgs

# Paths
LOCAL_TMP_PATH_ON_NODE = f"/scratch/{os.environ.get('USER')}"  # node-local scratch (checkpoints, dataset copies)
LAUNCH_CONFIGS_PATH = "path/to/launch-configs"

# Executables
NANOTRON_RUN_TRAIN_SCRIPT = "path/to/run_train.py"
S5CMD_PATH = "path/to/s5cmd"
S3_CHECKPOINTS_PREFIX = "path/to/where_to_save_checkpoints"

# Logging parameters
LOGS_PATH = "path/to/slurm-logs"
REPO_ID = "id of the repo to use for logging"  # NOTE(review): not used anywhere in this script
PROJECT = "name of the project"
EMAIL = "email to send notifications to"

# Resources parameters
NUM_GPUS = 8
NUM_CPUS_IN_NODE = 88
CPUS_PER_GPU = NUM_CPUS_IN_NODE // NUM_GPUS  # NOTE(review): not used anywhere in this script

model_config = LlamaConfig(
    # Config for a 1.46B model
    bos_token_id=1,
    eos_token_id=2,
    hidden_act="silu",
    hidden_size=2048,
    initializer_range=0.02,
    intermediate_size=8192,
    max_position_embeddings=2048,
    num_attention_heads=32,
    num_hidden_layers=14,
    num_key_value_heads=32,
    pretraining_tp=1,
    rms_norm_eps=1e-05,
    rope_scaling=None,
    tie_word_embeddings=True,
    use_cache=True,
    vocab_size=256008,  # gemma tokenizer + some room
)

# Approximate parameter count, used as the prefix of the run name:
# - one vocab x hidden embedding matrix (counted once; embeddings are tied above);
# - per layer, 3 hidden x intermediate MLP matrices and 4 hidden x hidden attention matrices.
# Norm weights are not counted. The attention term assumes
# num_key_value_heads == num_attention_heads, which holds for this config.
# '.' is replaced by 'p' in the formatted number.
num_params = human_format(
    model_config.vocab_size * model_config.hidden_size
    + model_config.num_hidden_layers
    * (
        3 * model_config.hidden_size * model_config.intermediate_size
        + 4 * model_config.hidden_size * model_config.hidden_size
    )
).replace(".", "p")
print(f"Model has {num_params} parameters")


def launch_slurm_job(launch_file_contents, *args):
    """Save an sbatch script to a temporary file and submit it.

    Args:
        launch_file_contents: contents of the sbatch script.
        *args: any other arguments to pass to the sbatch command (placed before the script path).

    Returns:
        The id of the launched slurm job: the last whitespace-separated token of sbatch's
        stdout ("Submitted batch job <id>").

    Raises:
        subprocess.CalledProcessError: if sbatch exits with a non-zero status.
    """
    with tempfile.NamedTemporaryFile("w") as f:
        f.write(launch_file_contents)
        f.flush()  # sbatch reads the file by name, so the contents must be on disk first
        return subprocess.check_output(["sbatch", *args, f.name]).decode("utf-8").split()[-1]


if __name__ == "__main__":
    import argparse
    from dataclasses import fields, is_dataclass

    from nanotron.config import get_config_from_file

    parser = argparse.ArgumentParser()
    parser.add_argument("data", help="dataset folder", type=str)
    parser.add_argument("run_name", help="run name", type=str)
    # NOTE(review): `language` is parsed but not used anywhere in this script.
    parser.add_argument("language", help="language", type=str)
    parser.add_argument("-d", help="dependency job", type=str, default=None)
    parser.add_argument("--seed", help="seed", type=int, default=6)
    parser.add_argument(
        "--train_steps",
        "-ts",
        help="training steps. Total_toks=seq_len*steps*micro_bs*batch_accum_per_replica*dp_size",
        type=int,
        default=14000,
    )
    parser.add_argument("--priority", "--qos", "-p", help="qos to use", type=str, default="normal")
    args = parser.parse_args()

    SEED = args.seed
    dataset_name = run_name = args.run_name.replace(" ", "_")

    # Specific name for this run (checkpoints/logs/tensorboard)
    RUN = f"{num_params}-{dataset_name}-seed-{SEED}"

    # Don't relaunch a run whose uploaded checkpoints already reached the requested step count.
    # Surrounding whitespace (e.g. a trailing newline) in latest.txt is ignored.
    df = get_datafolder(f"{S3_CHECKPOINTS_PREFIX}/{RUN}")
    if df.exists("latest.txt") and df.cat_file("latest.txt").decode("utf-8").strip() == str(args.train_steps):
        print(f"Not launching as latest checkpoint is already {args.train_steps} steps")
        sys.exit(0)

    import torch
    from nanotron.config import (
        CheckpointsArgs,
        Config,
        DataArgs,
        GeneralArgs,
        LlamaConfig,
        LoggingArgs,
        LRSchedulerArgs,
        ModelArgs,
        OptimizerArgs,
        ParallelismArgs,
        RandomInit,
        TokenizerArgs,
        TokensArgs,
        AdamWOptimizerArgs,
    )

    def print_differences(target, updates):
        """Recursively print fields whose value differs between two dataclass instances.

        Used to check that the config survives a YAML save/reload round trip. Fields that are
        ``None`` in ``updates`` are ignored. Differing leaf values whose type is not a builtin
        (e.g. ``Path`` or ``torch.dtype``) are skipped. A field that is a dataclass in ``updates``
        but not in ``target`` is reported as a difference.

        Raises:
            ValueError: if ``target`` or ``updates`` is not a dataclass instance.
        """
        if not is_dataclass(target) or not is_dataclass(updates):
            raise ValueError("Both target and updates should be dataclass instances")
        for field in fields(target):
            update_value = getattr(updates, field.name)
            if update_value is None:
                continue
            target_value = getattr(target, field.name)
            if is_dataclass(update_value) and is_dataclass(target_value):
                print_differences(target_value, update_value)
            elif is_dataclass(update_value):
                # Structural mismatch (e.g. None -> dataclass): report it rather than
                # crashing in the recursive call.
                print(f"{field.name}: {target_value} -> {update_value}")
            elif update_value != target_value:
                if update_value.__class__.__module__ != "builtins":
                    continue
                print(f"{field.name}: {target_value} -> {update_value}")

    # An s3:// dataset is copied by the sbatch script to this node-local folder before training.
    is_s3_dataset = args.data.startswith("s3://")
    local_dataset_path = f"{LOCAL_TMP_PATH_ON_NODE}/dataset/{RUN}/"

    data = [
        DatasetStageArgs(
            name="Training Stage",
            start_training_step=1,
            data=DataArgs(
                seed=SEED,
                num_loading_workers=0,
                dataset=NanosetDatasetsArgs(
                    dataset_folder=local_dataset_path if is_s3_dataset else args.data,
                    dataset_weights=None,
                ),
            ),
        ),
    ]

    general = GeneralArgs(
        project=PROJECT,
        run=RUN,
        ignore_sanity_checks=True,
        seed=SEED,
    )

    checkpoints = CheckpointsArgs(
        checkpoints_path=Path(f"{LOCAL_TMP_PATH_ON_NODE}/checkpoints/{RUN}"),
        checkpoints_path_is_shared_file_system=False,
        checkpoint_interval=500,
        save_initial_state=True,
    )

    # Number of nodes requested from Slurm. dp is derived from it so that dp * pp * tp
    # matches the NODES * NUM_GPUS processes started by the sbatch script.
    NODES = 8
    PIPELINE_PARALLEL_SIZE = 1
    TENSOR_PARALLEL_SIZE = 1
    parallelism = ParallelismArgs(
        dp=NODES * NUM_GPUS // (PIPELINE_PARALLEL_SIZE * TENSOR_PARALLEL_SIZE),
        pp=PIPELINE_PARALLEL_SIZE,
        tp=TENSOR_PARALLEL_SIZE,
        pp_engine="1f1b",
        tp_mode="REDUCE_SCATTER",
        tp_linear_async_communication=True,
    )

    tokens = TokensArgs(
        batch_accumulation_per_replica=4,
        micro_batch_size=4,
        sequence_length=2048,
        train_steps=args.train_steps,
        val_check_interval=-1,
    )

    model = ModelArgs(
        model_config=model_config,
        make_vocab_size_divisible_by=1,
        init_method=RandomInit(
            std=0.02
        ),
        dtype=torch.bfloat16,
    )

    logging = LoggingArgs(
        # 'debug', 'info', 'warning', 'error', 'critical' and 'passive'
        log_level="info",
        log_level_replica="info",
        iteration_step_info_interval=1,
    )

    optimizer = OptimizerArgs(
        accumulate_grad_in_fp32=True,
        clip_grad=1.0,
        weight_decay=0.1,
        zero_stage=0,
        learning_rate_scheduler=LRSchedulerArgs(
            learning_rate=3e-4,
            lr_warmup_steps=500,
            lr_warmup_style="linear",
            lr_decay_style="cosine",
            min_decay_lr=3.0e-5
        ),
        optimizer_factory=AdamWOptimizerArgs(
            adam_beta1=0.9,
            adam_beta2=0.95,
            adam_eps=1.0e-8,
            torch_adam_is_fused=True,
        ),
    )

    tokenizer = TokenizerArgs(
        tokenizer_name_or_path="google/gemma-7b",
    )

    s3_upload = S3UploadArgs(
        upload_s3_path=f"{S3_CHECKPOINTS_PREFIX}/{RUN}",
        remove_after_upload=True,
        s5cmd_numworkers=16,
        s5cmd_concurrency=5,
        s5cmd_path=S5CMD_PATH,
    )

    config = Config(
        general=general,
        checkpoints=checkpoints,
        parallelism=parallelism,
        model=model,
        tokenizer=tokenizer,
        logging=logging,
        tokens=tokens,
        optimizer=optimizer,
        data_stages=data,
        profiler=None,
        s3_upload=s3_upload,
        lighteval=None,
    )

    #### DEBUG MODE
    if os.environ.get("DEBUG_MODE", "0") != "0":
        print("##### WARNING DEBUG MODE #####")
        # 2 * 2 * 2 = 8 processes, i.e. a single node of NUM_GPUS GPUs.
        config.parallelism.dp = 2
        config.parallelism.pp = 2
        config.parallelism.tp = 2
        config.tokens.micro_batch_size = 3
        config.tokens.batch_accumulation_per_replica = 2
        config.checkpoints.save_initial_state = True
        NODES = 1

    # Sanity check that we can load, save to YAML and reload the config
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    os.makedirs(f"{LAUNCH_CONFIGS_PATH}/{run_name}", exist_ok=True)
    config_path_yaml = f"{LAUNCH_CONFIGS_PATH}/{run_name}/{timestamp}.yaml"
    config.save_as_yaml(config_path_yaml)
    config2 = get_config_from_file(config_path_yaml, config_class=Config)
    print_differences(config, config2)

    os.makedirs(f"{LOGS_PATH}/{run_name}", exist_ok=True)

    # For an s3:// dataset: wipe stale local copies, download with the configured s5cmd binary
    # (the same one used for checkpoint uploads), and remove this run's copy after training.
    if is_s3_dataset:
        dataset_download_cmd = (
            f"srun --ntasks-per-node=1 rm -rf {LOCAL_TMP_PATH_ON_NODE}/dataset\n"
            f"srun --ntasks-per-node=1 {S5CMD_PATH} cp '{args.data.removesuffix('/')}/*' {local_dataset_path}"
        )
        dataset_cleanup_cmd = f"srun --ntasks-per-node=1 rm -rf {local_dataset_path}"
    else:
        dataset_download_cmd = ""
        dataset_cleanup_cmd = ""

    job_name = f"{run_name}-{SEED}"
    sbatch_script = f"""#!/bin/bash
#SBATCH --job-name={job_name}
#SBATCH --nodes={NODES}
#SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node!
#SBATCH --cpus-per-task={NUM_CPUS_IN_NODE}
#SBATCH --gres=gpu:{NUM_GPUS}
#SBATCH --partition=hopper-prod
#SBATCH --output={LOGS_PATH}/{run_name}/train-{timestamp}-%x-%j
# #SBATCH --array=1-1%1
#SBATCH --qos={args.priority}
#SBATCH --begin=now+0minutes
#SBATCH --mail-type=ALL
#SBATCH --mail-user={EMAIL}
#SBATCH --requeue
{"#SBATCH --dependency=afterok:" + args.d if args.d else ""}

###########################################
# [BEGINING] ADAPT TO YOUR ENVIRONMENT

# [END] ADAPT TO YOUR ENVIRONMENT
###########################################


set -x -e

##### TO UPDATE #####

##### END TO UPDATE ######

echo "START TIME: $(date)"
secs_to_human(){{
    echo "$(( ${{1}} / 3600 )):$(( (${{1}} / 60) % 60 )):$(( ${{1}} % 60 ))"
}}
start=$(date +%s)
echo "$(date -d @${{start}} "+%Y-%m-%d %H:%M:%S"): ${{SLURM_JOB_NAME}} start id=${{SLURM_JOB_ID}}\n"

{dataset_download_cmd}

# SLURM stuff
export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"`
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_PORT=$((1024 + RANDOM % 64511))

export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l`

export TMPDIR={LOCAL_TMP_PATH_ON_NODE}

export CUDA_DEVICE_MAX_CONNECTIONS="1"

module load cuda/12.1

echo go $COUNT_NODE
echo $HOSTNAMES

##### MOVE TO YAML ######

CMD=" \
    {NANOTRON_RUN_TRAIN_SCRIPT} \
    --config-file {config_path_yaml}
    "

export LAUNCHER="python -u -m torch.distributed.run \
    --nproc_per_node {NUM_GPUS} \
    --nnodes $COUNT_NODE \
    --rdzv-backend c10d \
    --rdzv-endpoint $MASTER_ADDR:$MASTER_PORT \
    --rdzv-id $SLURM_JOB_ID \
    --node_rank $SLURM_PROCID \
    --role $SLURMD_NODENAME: \
    --max_restarts 0 \
    --tee 3 \
    "

# Wait a random number between 0 and 1000 (milliseconds) to avoid too many concurrent requests to the hub
random_milliseconds=$(( RANDOM % 1001 ))
sleep_time=$(bc <<< "scale=3; $random_milliseconds / 1000")
echo "Sleeping for $sleep_time seconds..."
sleep $sleep_time

launch_args="srun $SRUN_ARGS -u bash -c $LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD"

srun $SRUN_ARGS -u bash -c "$LAUNCHER --node_rank $SLURM_PROCID --role $SLURMD_NODENAME: $CMD"


echo "END TIME: $(date)"

{dataset_cleanup_cmd}
"""
    job_id = launch_slurm_job(sbatch_script)
    # Matches the "#SBATCH --output=...-%x-%j" pattern (%x = job name, %j = job id).
    log_path = f"{LOGS_PATH}/{run_name}/train-{timestamp}-{job_name}-{job_id}"

    print(f"Launched with Slurm job id={job_id}")
    print(f"To view the logs, use the command: tail -f {log_path}")