def parse_args()

in slurm_launcher.py


import argparse

# MODEL_SIZES and the DEFAULT_* constants referenced below are module-level
# definitions in slurm_launcher.py, not shown in this excerpt.


def parse_args():
    """Parse command line arguments for the Slurm launcher."""
    parser = argparse.ArgumentParser(
        description="Nanotron Slurm Launcher", formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Run identification (not strictly required; a default is provided)
    parser.add_argument("--run", type=str, default="nanotron", help="Name for this experiment run")

    # Slurm job configuration
    slurm_group = parser.add_argument_group("Slurm Configuration")
    slurm_group.add_argument("--gpus_per_node", type=int, default=8, help="Number of GPUs per node")
    slurm_group.add_argument("--partition", type=str, default="hopper-prod", help="Slurm partition to use")
    slurm_group.add_argument("--qos", type=str, default="normal", help="Slurm QOS to use")
    slurm_group.add_argument("--time_limit", type=str, default=None, help="Time limit for the job (HH:MM:SS)")
    slurm_group.add_argument("--email", type=str, default=None, help="Email for job notifications")
    slurm_group.add_argument("--tmp_dir", type=str, default="/tmp", help="Temporary directory on compute nodes")
    slurm_group.add_argument("--pre_launch_commands", type=str, default="", help="Commands to run before job launch")
    slurm_group.add_argument("--extra_env", type=str, default="", help="Additional environment variables")
    slurm_group.add_argument("--bench", type=str, default="", help="Benchmark csv path")

    # Config file
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to the Nanotron config file. If not provided, a config will be created automatically.",
    )

    # Model configuration
    model_group = parser.add_argument_group("Model Configuration")
    model_group.add_argument(
        "--model",
        type=str,
        default="custom",
        choices=MODEL_SIZES.keys(),
        help="Predefined model size",
    )
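    # Assumption: "custom" falls back to the explicit size flags below, while
    # the other MODEL_SIZES entries supply preset architectures (the mapping
    # itself is not shown in this excerpt).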
    model_group.add_argument("--hidden-size", type=int, default=None, help="Hidden size (overrides model)")
    model_group.add_argument("--intermediate-size", type=int, default=None, help="Intermediate size (overrides model)")
    model_group.add_argument("--num-layers", type=int, default=None, help="Number of layers (overrides model)")
    model_group.add_argument("--num-heads", type=int, default=None, help="Number of attention heads (overrides model)")
    model_group.add_argument("--num-kv-heads", type=int, default=None, help="Number of KV heads (overrides model)")
    model_group.add_argument("--vocab-size", type=int, default=65536, help="Vocabulary size (overrides model)")
    model_group.add_argument("--seq", type=int, default=4096, help="Maximum sequence length")

    # Training configuration
    training_group = parser.add_argument_group("Training Configuration")
    training_group.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    training_group.add_argument("--steps", type=int, default=10000, help="Number of training steps")
    training_group.add_argument("--mbs", type=int, default=2, help="Micro batch size")
    training_group.add_argument("--acc", type=int, default=8, help="Gradient accumulation steps")
    training_group.add_argument("--learning-rate", type=float, default=3e-4, help="Peak learning rate")
    training_group.add_argument("--min-lr", type=float, default=3e-5, help="Minimum learning rate for decay")
    training_group.add_argument("--weight-decay", type=float, default=0.01, help="Weight decay")
    training_group.add_argument("--grad-clip", type=float, default=1.0, help="Gradient clipping")
    training_group.add_argument("--warmup-steps", type=int, default=1000, help="Learning rate warmup steps")

    # Parallelism strategy
    parallel_group = parser.add_argument_group("Parallelism Configuration")
    parallel_group.add_argument("--dp", type=int, default=8, help="Data parallelism (DP) degree")
    parallel_group.add_argument("--pp", type=int, default=1, help="Pipeline parallelism (PP) degree")
    parallel_group.add_argument("--tp", type=int, default=2, help="Tensor parallelism (TP) degree")
    parallel_group.add_argument("--cp", type=int, default=1, help="Context parallelism degree")
    parallel_group.add_argument("--ep", type=int, default=1, help="Expert parallelism degree")
    parallel_group.add_argument("--zero", type=int, default=0, choices=[0, 1], help="ZeRO stage")

    # Dataset configuration
    data_group = parser.add_argument_group("Dataset Configuration")
    data_group.add_argument("--dataset", type=str, default=None, help="Hugging Face dataset name or path")
    data_group.add_argument("--text-column", type=str, default="text", help="Column name for text in the dataset")
    data_group.add_argument(
        "--tokenizer", type=str, default="robot-test/dummy-tokenizer-wordlevel", help="Tokenizer name or path"
    )

    # File paths
    paths_group = parser.add_argument_group("File Paths")
    paths_group.add_argument("--project", type=str, default="nanotron", help="Project name for logging")
    paths_group.add_argument(
        "--configs-path", type=str, default=DEFAULT_CONFIGS_PATH, help="Directory to save configuration files"
    )
    paths_group.add_argument(
        "--slurm-logs-path", type=str, default=DEFAULT_SLURM_LOGS_PATH, help="Directory for Slurm output logs"
    )
    paths_group.add_argument(
        "--checkpoints-path",
        type=str,
        default=DEFAULT_CHECKPOINTS_PATH,
        help="Base directory for saving model checkpoints",
    )
    paths_group.add_argument(
        "--run-train-script",
        type=str,
        default=DEFAULT_RUN_TRAIN_SCRIPT,
        help="Path to the training script (default: run_train.py)",
    )
    paths_group.add_argument(
        "--slurm-scripts-dir",
        type=str,
        default=DEFAULT_SLURM_SCRIPTS_DIR,
        help="Directory to save generated Slurm scripts (set to None to disable)",
    )
    paths_group.add_argument(
        "--save-interval", type=int, default=1000, help="Interval for saving checkpoints (in steps)"
    )
    paths_group.add_argument("--save-initial-state", action="store_true", help="Save initial state")

    # Logging configuration
    logging_group = parser.add_argument_group("Logging Configuration")
    logging_group.add_argument("--enable-wandb", action="store_true", help="Enable logging to Weights & Biases")
    logging_group.add_argument(
        "--profiler_export_path",
        type=str,
        default=None,
        help="Path to export the profiler tensorboard data. Use `tensorboard --logdir <path>` to view.",
    )
    logging_group.add_argument("--log-lvl", type=str, default="info", help="Log level")
    logging_group.add_argument("--no-sanity", action="store_true", help="Ignore sanity checks")

    # Execution control
    parser.add_argument("--dry-run", action="store_true", help="Generate configs but don't submit job")
    parser.add_argument("--show-logs", action="store_true", help="Show output of the job as it runs")
    return parser.parse_args()
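
A minimal usage sketch follows. The invocation and the __main__ block are hypothetical: flag values are illustrative, and the excerpt above does not show how slurm_launcher.py actually consumes the parsed arguments.

# Hypothetical invocation (illustrative values):
#   python slurm_launcher.py --run my-exp --dp 8 --tp 2 --steps 5000 \
#       --partition hopper-prod --time_limit 02:00:00 --dry-run

if __name__ == "__main__":
    args = parse_args()
    # argparse maps dashes in flag names to underscores in attribute names.
    total_gpus = args.dp * args.pp * args.tp * args.cp  # assumed world-size formula
    print(f"run={args.run}: {total_gpus} GPUs, lr={args.learning_rate}, steps={args.steps}")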