def create_nanotron_config()

in slurm_launcher.py [0:0]


def create_nanotron_config(args) -> Config:
    """
    Create a Nanotron configuration object based on the provided arguments.

    Besides building the config, this prints a rough parameter count and the
    global batch size / total training tokens, and creates the
    ``args.checkpoints_path`` directory if it does not exist.

    Args:
        args: Parsed command line arguments. Attributes read here include the
            model shape (``model``, ``hidden_size``, ``intermediate_size``,
            ``num_layers``, ``num_heads``, ``num_kv_heads``, ``vocab_size``,
            ``seq``), parallelism (``dp``, ``pp``, ``tp``, ``cp``, ``ep``),
            batching (``mbs``, ``acc``, ``steps``), optimization
            (``learning_rate``, ``warmup_steps``, ``min_lr``, ``zero``,
            ``weight_decay``, ``grad_clip``), bookkeeping (``project``,
            ``run``, ``seed``, ``no_sanity``, ``bench``, ``tokenizer``,
            ``log_lvl``), checkpointing (``checkpoints_path``,
            ``save_interval``, ``save_initial_state``) and
            ``profiler_export_path`` (optional; ``None`` disables profiling).

    Returns:
        Nanotron Config object
    """
    # Generate model configuration
    model_config = generate_model_config(
        model_size=args.model,
        hidden_size=args.hidden_size,
        intermediate_size=args.intermediate_size,
        num_hidden_layers=args.num_layers,
        num_attention_heads=args.num_heads,
        num_key_value_heads=args.num_kv_heads,
        vocab_size=args.vocab_size,
        max_position_embeddings=args.seq,
    )

    # Rough parameter count, for logging only: embedding/output matrices plus
    # per-layer attention and MLP weight matrices (norms and biases ignored).
    hidden = model_config.hidden_size
    num_heads = getattr(model_config, "num_attention_heads", None)
    # num_key_value_heads may be left unset (None) meaning "same as num_heads".
    num_kv_heads = getattr(model_config, "num_key_value_heads", None) or num_heads
    # Width of each k/v projection: the full hidden size for MHA, shrunk by
    # num_kv_heads / num_heads for GQA/MQA. Unknown head counts fall back to MHA.
    kv_dim = hidden * num_kv_heads // num_heads if num_heads else hidden
    # q and o projections are hidden x hidden; k and v are hidden x kv_dim.
    attention_params = 2 * hidden * hidden + 2 * hidden * kv_dim
    # Three hidden x intermediate matrices per layer (presumably the gate/up/down
    # projections of a gated MLP).
    mlp_params = 3 * hidden * model_config.intermediate_size
    # Input embedding plus output head; only one matrix when they are tied.
    embedding_params = model_config.vocab_size * hidden
    if not getattr(model_config, "tie_word_embeddings", False):
        embedding_params *= 2
    num_params = human_format(
        embedding_params + model_config.num_hidden_layers * (attention_params + mlp_params)
    )

    print(f"Model has {num_params} parameters")

    # Use user-provided parallelism directly
    parallelism = ParallelismArgs(
        dp=args.dp,
        pp=args.pp,
        tp=args.tp,
        context_parallel_size=args.cp,
        expert_parallel_size=args.ep,
        pp_engine="1f1b",
        tp_mode="REDUCE_SCATTER",
        tp_linear_async_communication=True,
        recompute_layer=False,
    )

    # Define tokens configuration
    tokens = TokensArgs(
        sequence_length=args.seq,
        train_steps=args.steps,
        micro_batch_size=args.mbs,
        batch_accumulation_per_replica=args.acc,
    )

    # Global batch size in tokens, for logging only.
    # NOTE(review): this treats every context-parallel and expert-parallel rank
    # as consuming its own slice of data (multiplying by cp and ep). Whether that
    # matches how Nanotron feeds data to cp/ep ranks is not visible here — confirm.
    gbs = (
        parallelism.dp
        * tokens.batch_accumulation_per_replica
        * tokens.micro_batch_size
        * tokens.sequence_length
        * parallelism.context_parallel_size
        * parallelism.expert_parallel_size
    )
    total_tokens = gbs * args.steps
    print(f"GBS: {(gbs)/1e6:.2f}M, total training tokens: {(total_tokens)/1e9:.2f}B")

    # Configure learning rate schedule: linear warmup, then cosine decay to min_lr.
    lr_scheduler = LRSchedulerArgs(
        learning_rate=args.learning_rate,
        lr_warmup_steps=args.warmup_steps,
        lr_warmup_style="linear",
        lr_decay_style="cosine",
        min_decay_lr=args.min_lr,
    )

    # Configure optimizer (AdamW with fixed betas/eps; fused torch kernel).
    optimizer = OptimizerArgs(
        zero_stage=args.zero,
        weight_decay=args.weight_decay,
        clip_grad=args.grad_clip,
        accumulate_grad_in_fp32=True,
        learning_rate_scheduler=lr_scheduler,
        optimizer_factory=AdamWOptimizerArgs(
            adam_eps=1e-08,
            adam_beta1=0.9,
            adam_beta2=0.95,
            torch_adam_is_fused=True,
        ),
    )

    # Configure datasets: a single stage starting at step 1.
    data_stages = [
        DatasetStageArgs(
            name="Stable Training Stage",
            start_training_step=1,
            data=DataArgs(
                # For pretraining:
                # dataset=PretrainDatasetsArgs(
                #     hf_dataset_or_datasets=args.dataset,
                #     text_column_name=args.text_column,
                # ),
                # When using a Nanoset, the model's vocab_size must be at least the
                # vocab size of the tokenizer that was used to tokenize the dataset.
                # NOTE(review): the dataset folder is hard-coded and does not come
                # from any command line argument.
                dataset=NanosetDatasetsArgs(
                    dataset_folder="/fsx/loubna/tokenized_for_exps/mcf-dataset",  # 1.4T tokens
                ),
                # For SFT (uncomment to use):
                # dataset=SFTDatasetsArgs(
                #     hf_dataset_or_datasets=args.dataset,
                #     hf_dataset_splits="train",
                #     debug_max_samples=1000,
                # ),
                seed=args.seed,
            ),
        ),
    ]
    # Configure checkpointing. Only the checkpoints root is created here; the
    # per-run subdirectory (root/run) is passed to Nanotron as-is.
    os.makedirs(args.checkpoints_path, exist_ok=True)
    checkpoints = CheckpointsArgs(
        checkpoints_path=os.path.join(args.checkpoints_path, args.run),
        checkpoint_interval=args.save_interval,
        save_initial_state=args.save_initial_state,
    )

    # Create the final config; profiling is enabled only when an export path is given.
    config = Config(
        general=GeneralArgs(
            project=args.project,
            run=args.run,
            seed=args.seed,
            ignore_sanity_checks=args.no_sanity,
            benchmark_csv_path=args.bench,
        ),
        checkpoints=checkpoints,
        parallelism=parallelism,
        model=ModelArgs(init_method=RandomInit(std=0.025), model_config=model_config),
        tokenizer=TokenizerArgs(args.tokenizer),
        optimizer=optimizer,
        logging=LoggingArgs(log_level=args.log_lvl, log_level_replica=args.log_lvl, iteration_step_info_interval=1),
        tokens=tokens,
        data_stages=data_stages,
        profiler=ProfilerArgs(profiler_export_path=args.profiler_export_path)
        if args.profiler_export_path is not None
        else None,
    )

    return config