def main()

in scripts/scaling_benchmarks.py [0:0]


def main():
    parser = argparse.ArgumentParser(description="Run scaling benchmarks with different parallelism configurations")
    parser.add_argument(
        "--configs-dir",
        type=str,
        default="benchmark/configs",
        help="Directory to store generated configs",
    )
    parser.add_argument(
        "--scripts-dir",
        type=str,
        default="benchmark/scripts",
        help="Directory to store generated SLURM scripts",
    )
    parser.add_argument("--partition", type=str, default="hopper-prod", help="SLURM partition to use")
    parser.add_argument("--time", type=str, default="00:40:00", help="Time limit for each job")
    parser.add_argument(
        "--base-config",
        type=str,
        default="examples/config_tiny_llama_bench.yaml",
        help="Base configuration file to use",
    )
    parser.add_argument(
        "--base-script",
        type=str,
        default="run_multinode.sh",
        help="Base SLURM script to use",
    )
    parser.add_argument(
        "--pending-csv",
        type=str,
        default="benchmark/results/pending_experiments2.csv",
        help="CSV file to store pending experiments",
    )
    parser.add_argument(
        "--benchmark-csv",
        type=str,
        default="benchmark/results/bench_final2.csv",
        help="CSV file to store benchmark results",
    )
    parser.add_argument(
        "--run",
        action="store_true",
        help="Automatically submit all generated SLURM scripts",
    )
    parser.add_argument("--debug", action="store_true", help="Debug mode")
    parser.add_argument(
        "--limit",
        type=str,
        default=None,
        help="Limit the number of configurations to run (e.g. 100:200)",
    )
    parser.add_argument("--profile", action="store_true", help="Enable profiling")
    parser.add_argument("--use-bash", action="store_true", help="Use bash instead of sbatch")
    args = parser.parse_args()

    # Parse limit argument if provided
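    # e.g. "--limit 100:200" -> slice(100, 200) (configs 100-199); "--limit 50" -> slice(50) (first 50)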
    if args.limit is not None:
        if ":" in args.limit:
            start, end = args.limit.split(":")
            start = int(start) if start else None
            end = int(end) if end else None
            args.limit = slice(start, end)
        else:
            args.limit = slice(int(args.limit))

    # Validate input files exist
    if not os.path.exists(args.base_config):
        raise FileNotFoundError(f"Base config file not found: {args.base_config}")
    if not os.path.exists(args.base_script):
        raise FileNotFoundError(f"Base script file not found: {args.base_script}")

    # Create directories if they don't exist
    for directory in [args.configs_dir, args.scripts_dir]:
        os.makedirs(directory, exist_ok=True)

    # Define model configurations
    model_configs = {
        # (layers, hidden_size, heads, intermediate_size)
        # "1B": (16, 2048, 32, 8192),  # 1.2G
        "3B": (28, 3072, 32, 8192),  # 3.57G  24heads -> 32heads
        # "4B": (30, 3072, 32, 8192),  # 30 layers distributed among PP-2
        # "8B": (32, 4096, 32, 14336),  # 8.0G
        # "70B": (80, 8192, 64, 28672),  # 70G
        # "405B": (126, 16384, 128, 53248),  # 406G
    }

    # Define configurations to test
    configurations = []

    # For each model size, test different GPU configurations
    # for model_name, (num_layers, hidden_size, num_heads, intermediate_size) in model_configs.items():
    #     vocab_size = 32768
    #     zero_stage = 0
    #     tp_mode = "REDUCE_SCATTER"
    #     configs = [  # 64 nodes max
    #         # 2k, 8k, 32k
    #         # GBS: 1M, 4M
    #         # Format: (dp, tp, pp, batch_accum, seq_len, mbs, ...)
    #         # Using SP what's the biggest seqlen we can fit?
    #         # (1, 8, 1, 1, 2048, 1, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode),
    #         # (1, 8, 1, 1, 2048, 2, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode),
    #         # (1, 8, 1, 1, 2048, 8, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode),
    #         # (1, 8, 1, 1, 2048, 32, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode),
    #         # best run
    #         # (1, 8, 1, 1, 2048, 64, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, 0, tp_mode),

    #         # test zero
    #         # (3, 8, 1, 1, 2048, 64, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, 0, tp_mode),
    #         # (3, 8, 1, 1, 2048, 64, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, 1, tp_mode),
    #         # (24, 1, 1, 1, 2048, 8, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, 0, tp_mode),
    #         # (24, 1, 1, 1, 2048, 8, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, 1, tp_mode),
    #         # test tp mode
    #         # (1, 8, 1, 1, 2048, 64, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, "ALL_REDUCE"),
    #         # test pp
    #         # (1, 1, 8, 1, 2048, 64, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode),
    #         # (1, 8, 2, 1, 2048, 64, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode),
    #         # (1, 1, 8, 8, 2048, 8, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode),
    #         # (1, 2, 8, 8, 2048, 8, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode),
    #         # (1, 2, 64, 8, 2048, 8, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode),
    #         # (1, 2, 16, 8, 2048, 8, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode),
    #     ]
    #     configurations.extend(configs)


    # TP scaling tests with corresponding max batch sizes
    # tp_mbs_configs = [
    #     # Format: (tp, mbs)
    #     # TP=1 OOMs
    #     (2, 3),    # 363.66 TFLOPs, 14167.25 tok/s/gpu  
    #     (4, 9),   # 345.51 TFLOPs, 13460.16 tok/s/gpu (-5%)
    #     (8, 18),   # 279.50 TFLOPs, 10888.53 tok/s/gpu (-19%)
    #     (16, 40),  # 158.10 TFLOPs, 6159.30 tok/s/gpu (-43%)
    #     (32, 90), # 92.66 TFLOPs, 3609.73 tok/s/gpu (-41%)
    # ]
    # TP_, MBS_ = tp_mbs_configs[4]


    # Method 2: Parameter combinations
    PARALLEL_CONFIGS = [(dp, tp, pp)
                        for dp in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
                        for tp in [1, 2, 4, 8, 16, 32]
                        for pp in [2]]
    # Sort PARALLEL_CONFIGS by total GPU count (dp*tp*pp) ascending
    PARALLEL_CONFIGS = sorted(PARALLEL_CONFIGS, key=lambda x: x[0] * x[1] * x[2])
    SEQUENCE_LENGTHS = [4096]
    # MBS = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
    MBS = [2]
    GRAD_ACCUM_STEPS = [1]
    VOCAB_SIZES = [131072]  # alternatives: 49152, 131072
    ZERO_STAGES = [0]  # stage 0 suffices when dp >= 32 and the model is small (<80B), i.e. memory is not a constraint
    TP_MODES = ["REDUCE_SCATTER"]
    # TP_MODES = ["ALL_REDUCE"]
    GBS = [512 * 2048]  # 1M
    MIN_NODES = 0
    MAX_NODES = 10
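    # Note: 8 GPUs per node is assumed throughout (total_gpus / 8 below, and NNODES at submission time)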

    est_minutes = 0
    TIME_PER_CONFIG = 2  # assumed minutes per config
    counter = 0
    configurations = []
    for dp, tp, pp in PARALLEL_CONFIGS:
        for model_name, (num_layers, hidden_size, num_heads, intermediate_size) in model_configs.items():
            for seq_len in SEQUENCE_LENGTHS:
                for mbs in MBS:
                    for batch_accum in GRAD_ACCUM_STEPS:
                        for vocab_size in VOCAB_SIZES:
                            for zero_stage in ZERO_STAGES:
                                for tp_mode in TP_MODES:
                                    # batch_accum = pp-1

                                    # Optional: Add conditions to filter out unwanted combinations
                                    total_gpus = dp * tp * pp
                                    if not MIN_NODES <= total_gpus / 8 <= MAX_NODES:
                                        print(f"Skipping config - nodes {total_gpus/8} not in range [{MIN_NODES}, {MAX_NODES}]")
                                        continue

                                    tokens_per_step = dp * mbs * batch_accum * seq_len
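                                    # The commented check below would pin tokens_per_step (the global batch, in tokens) to GBS (~1M)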
                                    # if tokens_per_step not in GBS:
                                    #     continue
                                    # if batch_accum > 1:
                                    #     print(f"Skipping config - batch_accum {batch_accum} > 1")
                                    #     continue

                                    # if dp=1 skip zero stage 1
                                    if dp == 1 and zero_stage == 1:
                                        print(f"Skipping config - dp=1 with zero stage 1")
                                        continue

                                    # if tp=1 skip tp_mode=ALL_REDUCE
                                    # if tp == 1 and tp_mode == "ALL_REDUCE":
                                    #     print(f"Skipping config - tp=1 with ALL_REDUCE")
                                    #     continue

                                    if batch_accum < pp - 1:
                                        print(f"Skipping config - batch_accum {batch_accum} < pp-1 ({pp-1})")
                                        continue

                                    if model_name == "1B" and pp > 21:  # more pipeline stages than the layer count supports
                                        print(f"Skipping config - 1B model with pp {pp} > 21")
                                        continue

                                    config = (dp, tp, pp, batch_accum, seq_len, mbs, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode)
                                    if config not in configurations:
                                        counter += 1
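                                        # node-minutes for this config, amortized over MAX_NODES concurrently running nodes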
                                        est_minutes += total_gpus * TIME_PER_CONFIG / 8 / MAX_NODES
                                        configurations.append(config)

    # print(f"experiments: {counter}")
    # print(f"time (days): {time/60/24} | {time/60:.2f} hours")
    # print(len(configurations))

    # # Load configs from pickle file
    # import pickle
    # with open('configs.pkl', 'rb') as f:
    #     configurations = pickle.load(f)

    # validate configs
    new_configs = []
    for config in configurations:
        # config = (dp, tp, pp, batch_accum, seq_len, mbs, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode)
        dp, tp, pp, batch_accum, seq_len, mbs, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode = config
        tokens_per_step = dp * mbs * batch_accum * seq_len
        # if tokens_per_step not in GBS:
        #     print(f"Invalid config: {config} | tokens_per_step: {tokens_per_step}")
            # continue
        if dp == 1 and zero_stage == 1:
            print(f"Invalid config: {config} | dp: {dp} | zero_stage: {zero_stage}")
            continue
        # if tp == 1 and tp_mode == "ALL_REDUCE":
        #     print(f"Invalid config: {config} | tp: {tp} | tp_mode: {tp_mode}")
        #     continue
        if batch_accum < pp - 1:
            print(f"Invalid config: {config} | batch_accum: {batch_accum} | pp: {pp}")
            continue
        new_configs.append(config)
    configurations = new_configs

    print(f"{len(configurations)} configurations to run")

    if args.debug:
        print("Debug mode: only running 1 configuration")
        configurations = configurations[:1]

    if isinstance(args.limit, slice):
        print(f"Limiting to {args.limit} configurations")
        configurations = configurations[args.limit]
    elif isinstance(args.limit, int):
        print(f"Limiting to {args.limit} configurations")
        configurations = configurations[: args.limit]

    # Optionally hardcode a subset of configurations to run, e.g.:
    # configurations = configurations[:120+5000]

    # Load previous results so runs that already finished (Success or OOM) can be skipped below
    import pandas as pd
    old_results_df = pd.read_csv('benchmark/results/bench_final2_mfu2.csv')
    old_results_df = old_results_df[old_results_df['status'].isin(['Success', 'OOM'])]

    # Generate configs and scripts
    generated_scripts = []
    configs = []
    for dp, tp, pp, batch_accum, seq_len, mbs, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode in tqdm(configurations, desc="Generating configs and scripts"):
        try:
            # Create config
            config = create_config(
                dp=dp,
                tp=tp,
                pp=pp,
                batch_accum=batch_accum,
                seq_len=seq_len,
                micro_batch_size=mbs,
                base_config_path=args.base_config,
                num_layers=num_layers,
                hidden_size=hidden_size,
                num_attention_heads=num_heads,
                intermediate_size=intermediate_size,
                vocab_size=vocab_size,
                zero_stage=zero_stage,
                tp_mode=tp_mode,
                profile=args.profile,
                benchmark_csv_path=args.benchmark_csv,
            )

            if config['general']['run'] in old_results_df['name'].values:
                # pp==1 results are always trusted; pp>1 results only count if they come from job_id >= 14097150
                if pp == 1:
                    print(f"Skipping {config['general']['run']} because it already exists in old_results_df")
                    continue
                elif int(old_results_df[old_results_df['name'] == config['general']['run']]['job_id'].values[0]) >= 14097150:
                    print(f"Skipping {config['general']['run']} because it already exists in old_results_df")
                    continue

            # Save config
            config_path = os.path.join(args.configs_dir, f"config_{config['general']['run']}.yaml")
            with open(config_path, "w") as f:
                yaml.dump(config, f, default_flow_style=False)

            # Generate and save SLURM script
            script = generate_slurm_script(config, dp, tp, pp, time=args.time, partition=args.partition, base_script_path=args.base_script, use_bash=args.use_bash)

            script_path = os.path.join(args.scripts_dir, f"run_{config['general']['run']}.sh")
            with open(script_path, "w") as f:
                f.write(script)

            # Make script executable
            os.chmod(script_path, 0o755)

            generated_scripts.append(script_path)
            configs.append(config)

        except Exception as e:
            print(f"Error processing configuration (dp={dp}, tp={tp}, pp={pp}): {str(e)}")

    # Submit jobs if requested
    job_ids = []
    if args.run:
        import subprocess

        print("\nSubmitting jobs...")
        for script_path, config in tqdm(zip(generated_scripts, configs), total=len(generated_scripts), desc="Submitting jobs"):
            try:
                if args.use_bash:
                    env = os.environ.copy()
                    salloc_jobid = os.environ.get("SALLOC_JOBID")
                    if not salloc_jobid:
                        raise ValueError("SALLOC_JOBID environment variable is required but not set. Please define it in your environment.")
                    env["SALLOC_JOBID"] = os.environ.get("SALLOC_JOBID")
                    env["NNODES"] = str(config["parallelism"]["dp"] * config["parallelism"]["tp"] * config["parallelism"]["pp"] // 8)
                    result = subprocess.run(["bash", script_path], check=True, env=env)
                    job_id = None  # No job ID for bash execution
                    print(f"bash {script_path}")
                else:
                    result = subprocess.run(["sbatch", script_path], check=True, capture_output=True, text=True)
                    # Extract job ID from sbatch output (format: "Submitted batch job 123456")
                    job_id = result.stdout.strip().split()[-1]
                    print(f"sbatch {script_path}: {result.stdout.strip()}")
                job_ids.append(job_id)
            except subprocess.CalledProcessError as e:
                print(f"Error {'running' if args.use_bash else 'submitting'} {script_path}: {e.stderr}")
                job_ids.append(None)

        # Save configs with job IDs
        save_experiment_configs(configs, args.pending_csv, job_ids=job_ids)

    else:
        print("\nTo run individual jobs:")
        for script_path in generated_scripts:
            print(f"sbatch {script_path}")
            job_ids.append(None)
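
A minimal invocation sketch (flags as defined in the argument parser above; paths assume the repository root as the working directory):

    # generate configs and SLURM scripts for the first 10 configurations, without submitting
    python scripts/scaling_benchmarks.py --limit 0:10

    # generate and submit via sbatch, recording pending jobs to the CSV
    python scripts/scaling_benchmarks.py --limit 0:10 --run

    # run a single configuration with bash inside an existing allocation (SALLOC_JOBID must be set)
    SALLOC_JOBID=<jobid> python scripts/scaling_benchmarks.py --debug --run --use-bash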