# Excerpt: def main() from method_comparison/MetaMathQA/run.py

def main(*, path_experiment: str, experiment_name: str, clean: bool) -> None:
    """Run a single MetaMathQA training experiment end to end.

    Loads the (optional) PEFT config and the train config from
    ``path_experiment``, builds tokenizer and model, trains, and finally
    prints/persists all collected metrics via ``log_results``. Exits the
    process with status 1 if training fails.

    Args:
        path_experiment: Directory holding the experiment's config files.
        experiment_name: Identifier recorded alongside the results.
        clean: Forwarded to ``get_file_size`` (whether to clean up artifacts).
    """
    t0 = time.perf_counter()
    started_at = dt.datetime.now(tz=dt.timezone.utc).replace(microsecond=0).isoformat()

    # A run off the 'main' PEFT branch is treated as authoritative; anything
    # else is considered a test run.
    branch = get_peft_branch()
    if branch == "main":
        print_verbose("===== This experiment is categorized as a MAIN run because the PEFT branch is 'main' ======")
    else:
        print_verbose(
            f"===== This experiment is categorized as a TEST run because the PEFT branch is '{branch}' ======"
        )

    # Load configs; a missing PEFT config means full finetuning (no adapter).
    peft_config: Optional[PeftConfig] = None
    if os.path.exists(os.path.join(path_experiment, CONFIG_NAME)):
        peft_config = PeftConfig.from_pretrained(path_experiment)
    else:
        print_verbose(f"Could not find PEFT config at {path_experiment}, performing FULL FINETUNING")
    train_config = get_train_config(os.path.join(path_experiment, FILE_NAME_TRAIN_PARAMS))
    set_seed(train_config.seed)

    # Record the CUDA memory baseline before any model is instantiated.
    cuda_mem_baseline = init_cuda()
    tokenizer = get_tokenizer(model_id=train_config.model_id, max_seq_length=train_config.max_seq_length)

    # Metadata about the base model and the train/eval datasets, for logging.
    model_info = get_base_model_info(train_config.model_id)
    metamath_info = get_dataset_info("meta-math/MetaMathQA")
    gsm8k_info = get_dataset_info("openai/gsm8k")
    model = get_model(
        model_id=train_config.model_id,
        dtype=train_config.dtype,
        compile=train_config.compile,
        attn_implementation=train_config.attn_implementation,
        peft_config=peft_config,
        autocast_adapter_dtype=train_config.autocast_adapter_dtype,
    )
    print_verbose(model)

    # Run the training loop; AdaLoRA needs special handling inside train().
    result = train(
        model=model,
        max_steps=train_config.max_steps,
        batch_size=train_config.batch_size,
        batch_size_eval=train_config.batch_size_eval,
        tokenizer=tokenizer,
        cuda_memory_init=cuda_mem_baseline,
        eval_steps=train_config.eval_steps,
        generation_kwargs=train_config.generation_kwargs,
        grad_norm_clip=train_config.grad_norm_clip,
        optimizer_type=train_config.optimizer_type,
        optimizer_kwargs=train_config.optimizer_kwargs,
        query_template=train_config.query_template,
        lr_scheduler_arg=train_config.lr_scheduler,
        use_amp=train_config.use_amp,
        is_adalora=isinstance(peft_config, AdaLoraConfig),
    )

    # A failed run produces no results worth logging; signal failure to caller.
    if result.status == TrainStatus.FAILED:
        print_verbose("Training failed, not logging results")
        sys.exit(1)

    checkpoint_size = get_file_size(
        model,
        peft_config=peft_config,
        clean=clean,
        print_fn=print_verbose,
    )

    elapsed = time.perf_counter() - t0
    # Log results: print them and save them to file.
    log_results(
        experiment_name=experiment_name,
        train_result=result,
        cuda_memory_init=cuda_mem_baseline,
        time_total=elapsed,
        file_size=checkpoint_size,
        model_info=model_info,
        datasets_info={"metamath": metamath_info, "gsm8k": gsm8k_info},
        start_date=started_at,
        train_config=train_config,
        peft_config=peft_config,
        print_fn=print_verbose,
    )