in slurm_launcher.py [0:0]
def main():
    """Drive one launcher run: prepare the config, build the Slurm script, then preview or submit it.

    The steps are, in order:

    1. Parse the command-line arguments and create the config/log root directories.
    2. Build a config from the arguments, or load it from ``args.config`` when given.
    3. Save the config as a timestamped YAML file under ``<configs_path>/<run_name>/``.
    4. Generate the Slurm script and optionally save it to ``args.slurm_scripts_dir``.
    5. Print the script (``args.dry_run``) or submit it, optionally tailing its log.

    Returns:
        int: 0 once the dry run or the submission path has completed.
    """
    args = parse_args()

    # Both root directories must exist before anything is written below them.
    os.makedirs(args.configs_path, exist_ok=True)
    os.makedirs(args.slurm_logs_path, exist_ok=True)

    # The parallelism sizes come from the loaded config when a YAML file is
    # given, and straight from the CLI arguments otherwise.
    if args.config is not None:
        print(f"🔍 Loading config from {args.config}")
        config = Config.load_from_yaml(args.config)
        dp = config.parallelism.dp
        pp = config.parallelism.pp
        tp = config.parallelism.tp
        cp = config.parallelism.context_parallel_size
        ep = config.parallelism.expert_parallel_size
    else:
        config = create_nanotron_config(args)
        dp, pp, tp, cp, ep = args.dp, args.pp, args.tp, args.cp, args.ep

    # Benchmark mode: record the CSV destination on the config.
    if args.bench:
        config.general.benchmark_csv_path = args.bench

    # Every artifact of this run shares the "<timestamp>-<run name>" stem;
    # spaces in the run name are replaced so it is usable in file names.
    stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    safe_run_name = args.run.replace(" ", "_")
    file_stem = f"{stamp}-{safe_run_name}"

    # Persist the config in a per-run subdirectory.
    run_config_dir = os.path.join(args.configs_path, safe_run_name)
    os.makedirs(run_config_dir, exist_ok=True)
    config_path = os.path.join(run_config_dir, f"{file_stem}.yaml")
    config.save_as_yaml(config_path)
    print(f"💾 Config saved to {config_path}")
    config.print_config_details()

    # Generate the job script text from the saved config and parallelism sizes.
    script_text = create_slurm_script(config_path, args, dp, pp, tp, cp, ep, args.run_train_script)

    # Keep a copy of the script on disk only when a directory was requested.
    if args.slurm_scripts_dir is not None:
        os.makedirs(args.slurm_scripts_dir, exist_ok=True)
        script_file_path = os.path.join(args.slurm_scripts_dir, f"{file_stem}.sh")
        with open(script_file_path, "w") as script_file:
            script_file.write(script_text)
        print(f"💾 Slurm script saved to {script_file_path}")

    # Dry run: show what would be submitted and stop here.
    if args.dry_run:
        print("DRY RUN - Job script:")
        print(script_text)
        print(f"🔍 Would submit job with config from {config_path}")
        return 0

    # Real run: submit the script and report the job id.
    submitted_id = launch_slurm_job(script_text)
    print(f"🚀 Slurm job submitted with JOBID: {submitted_id}")

    # Log location announced to the user and, if requested, followed below.
    log_file = os.path.join(args.slurm_logs_path, safe_run_name, f"{file_stem}-{submitted_id}.out")
    print(
        f"🔍 Logs will be available at: {log_file}"
    )

    # Tail output file when available
    if args.show_logs:
        tail_output_file(log_file)

    return 0