def main()

in phi3/src_train/train_mlflow.py [0:0]


def main(args):
    """Fine-tune a causal LM with LoRA using ``SFTTrainer`` inside an MLflow run.

    The function, in order:

    1. Exports Weights & Biases settings to the environment when the matching
       arguments are provided, and decides whether to report to W&B.
    2. Builds the trainer configuration (``SFTConfig``) and the LoRA
       configuration (``LoraConfig``), then loads the model and tokenizer via
       ``load_model(args)``.
    3. Configures logging for this module, ``datasets`` and ``transformers``.
    4. Loads ``train.jsonl`` and ``eval.jsonl`` from ``args.train_dir`` and maps
       ``apply_chat_template`` over both datasets.
    5. Inside ``mlflow.start_run()``: logs both configuration dicts, trains,
       logs runtime / GPU memory statistics, saves metrics and trainer state,
       and writes the model (merged or PEFT adapter) plus the tokenizer to
       ``args.model_dir``.

    Args:
        args: Namespace-like object. Attributes read here:
            ``wandb_project``, ``wandb_api_key``, ``wandb_watch``,
            ``wandb_log_model``, ``wandb_run_name``, ``learning_rate``,
            ``logging_steps``, ``lr_scheduler_type``, ``epochs``,
            ``output_dir``, ``train_batch_size``, ``eval_batch_size``,
            ``save_steps``, ``seed``, ``grad_accum_steps``, ``warmup_ratio``,
            ``max_seq_length``, ``lora_r``, ``lora_alpha``, ``lora_dropout``,
            ``train_dir``, ``model_dir``, ``save_merged_model``.
            It is also passed as-is to ``load_model``.

    Returns:
        None. Results are written to ``args.output_dir`` / ``args.model_dir``
        and to the active MLflow run.

    Note:
        CUDA device 0 is queried for memory statistics, so a CUDA device must
        be available.
    """

    ###################
    # Hyper-parameters
    ###################
    # Only overwrite environ if wandb param passed. Truthiness checks are used
    # so that both "" and None count as "not passed".
    if args.wandb_project:
        # Do not clobber an API key that is already in the environment when no
        # key was supplied on the command line.
        if args.wandb_api_key:
            os.environ["WANDB_API_KEY"] = args.wandb_api_key
        os.environ["WANDB_PROJECT"] = args.wandb_project
    if args.wandb_watch:
        os.environ["WANDB_WATCH"] = args.wandb_watch
    if args.wandb_log_model:
        os.environ["WANDB_LOG_MODEL"] = args.wandb_log_model

    # W&B reporting is enabled by the argument or by a pre-set environment variable.
    use_wandb = bool(args.wandb_project) or bool(os.environ.get("WANDB_PROJECT"))

    # Kept as a plain dict (like `peft_config` below) so the exact values used
    # for training can also be logged to MLflow further down.
    training_config = {
        "bf16": True,
        "do_eval": False,
        "learning_rate": args.learning_rate,
        "log_level": "info",
        "logging_steps": args.logging_steps,
        "logging_strategy": "steps",
        "lr_scheduler_type": args.lr_scheduler_type,
        "num_train_epochs": args.epochs,
        "max_steps": -1,
        "output_dir": args.output_dir,
        "overwrite_output_dir": True,
        "per_device_train_batch_size": args.train_batch_size,
        "per_device_eval_batch_size": args.eval_batch_size,
        "remove_unused_columns": True,
        "save_steps": args.save_steps,
        "save_total_limit": 1,
        "seed": args.seed,
        "gradient_checkpointing": True,
        "gradient_checkpointing_kwargs": {"use_reentrant": False},
        "gradient_accumulation_steps": args.grad_accum_steps,
        "warmup_ratio": args.warmup_ratio,
        "max_seq_length": args.max_seq_length,
        "packing": True,
        "report_to": "wandb" if use_wandb else "none",
        "run_name": args.wandb_run_name if use_wandb else None,
    }

    peft_config = {
        "r": args.lora_r,
        "lora_alpha": args.lora_alpha,
        "lora_dropout": args.lora_dropout,
        "bias": "none",
        "task_type": "CAUSAL_LM",
        #"target_modules": "all-linear",
        "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        "modules_to_save": None,
    }

    train_conf = SFTConfig(**training_config)
    peft_conf = LoraConfig(**peft_config)
    model, tokenizer = load_model(args)

    ###############
    # Setup logging
    ###############
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = train_conf.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process a small summary.
    # Training runs in bf16 here, so report either 16-bit mode rather than
    # fp16 alone (which would always print False).
    logger.warning(
        f"Process rank: {train_conf.local_rank}, device: {train_conf.device}, n_gpu: {train_conf.n_gpu}"
        + f" distributed training: {bool(train_conf.local_rank != -1)}, 16-bits training: {train_conf.fp16 or train_conf.bf16}"
    )
    logger.info(f"Training/evaluation parameters {train_conf}")
    logger.info(f"PEFT parameters {peft_conf}")

    ##################
    # Data Processing
    ##################
    train_dataset = load_dataset('json', data_files=os.path.join(args.train_dir, 'train.jsonl'), split='train')
    eval_dataset = load_dataset('json', data_files=os.path.join(args.train_dir, 'eval.jsonl'), split='train')
    # Original columns are dropped after templating; the names are taken from
    # the train split and reused for the eval split.
    column_names = list(train_dataset.features)

    processed_train_dataset = train_dataset.map(
        apply_chat_template,
        fn_kwargs={"tokenizer": tokenizer},
        num_proc=10,
        remove_columns=column_names,
        desc="Applying chat template to train_sft",
    )

    processed_eval_dataset = eval_dataset.map(
        apply_chat_template,
        fn_kwargs={"tokenizer": tokenizer},
        num_proc=10,
        remove_columns=column_names,
        desc="Applying chat template to test_sft",
    )

    with mlflow.start_run() as run:

        log_params_from_dict(training_config, mlflow)
        log_params_from_dict(peft_config, mlflow)

        ###########
        # Training
        ###########
        trainer = SFTTrainer(
            model=model,
            args=train_conf,
            peft_config=peft_conf,
            train_dataset=processed_train_dataset,
            eval_dataset=processed_eval_dataset,
            tokenizer=tokenizer
        )

        # Show current memory stats (bytes -> GB)
        gpu_stats = torch.cuda.get_device_properties(0)
        start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
        max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
        logger.info(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
        logger.info(f"{start_gpu_memory} GB of memory reserved.")

        trainer_stats = trainer.train()

        #############
        # Logging
        #############
        metrics = trainer_stats.metrics

        # Show final memory and time stats; the "for training" figures are the
        # peak reserved memory minus what was already reserved before train().
        used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
        used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
        used_percentage = round(used_memory / max_memory * 100, 3)
        lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)

        logger.info(f"{metrics['train_runtime']} seconds used for training.")
        logger.info(f"{round(metrics['train_runtime']/60, 2)} minutes used for training.")
        logger.info(f"Peak reserved memory = {used_memory} GB.")
        logger.info(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
        logger.info(f"Peak reserved memory % of max memory = {used_percentage} %.")
        logger.info(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

        #############
        # Evaluation
        #############
        # Evaluation is currently disabled (do_eval=False above).
        # tokenizer.padding_side = "left"
        # metrics = trainer.evaluate()
        # metrics["eval_samples"] = len(processed_eval_dataset)
        # trainer.log_metrics("eval", metrics)
        # trainer.save_metrics("eval", metrics)

        # ############
        # # Save model
        # ############
        os.makedirs(args.model_dir, exist_ok=True)

        if args.save_merged_model:
            # Save the adapter to a scratch directory, reload it together with
            # its base model, merge the adapter weights, then save the merged
            # model to the final location.
            model_tmp_dir = "model_tmp"
            os.makedirs(model_tmp_dir, exist_ok=True)
            trainer.model.save_pretrained(model_tmp_dir)
            print(f"Save merged model: {args.model_dir}")
            from peft import AutoPeftModelForCausalLM
            # Distinct name so the model handed to the trainer is not shadowed.
            peft_model = AutoPeftModelForCausalLM.from_pretrained(
                model_tmp_dir, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16
            )
            merged_model = peft_model.merge_and_unload()
            merged_model.save_pretrained(args.model_dir, safe_serialization=True)
        else:
            print(f"Save PEFT model: {args.model_dir}")
            trainer.model.save_pretrained(args.model_dir)

        tokenizer.save_pretrained(args.model_dir)