def main()

in phi3/src_train/train.py [0:0]
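
Entry point for LoRA fine-tuning: it configures optional Weights & Biases logging, builds the LoRA and SFT configurations, applies the chat template to the train/eval JSONL datasets, trains with SFTTrainer (resuming from the latest checkpoint if one exists), logs GPU memory usage, evaluates, and saves either the PEFT adapter or a merged model.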


def main(args):
    
    ###################
    # Hyper-parameters
    ###################
    # Only overwrite environ if wandb param passed
    if len(args.wandb_project) > 0:
        os.environ['WANDB_API_KEY'] = args.wandb_api_key    
        os.environ["WANDB_PROJECT"] = args.wandb_project
    if len(args.wandb_watch) > 0:
        os.environ["WANDB_WATCH"] = args.wandb_watch
    if len(args.wandb_log_model) > 0:
        os.environ["WANDB_LOG_MODEL"] = args.wandb_log_model
        
    # Enable W&B reporting when a project is supplied via args or already present in the environment
    use_wandb = len(args.wandb_project) > 0 or ("WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0)
        
    peft_config = {
        "r": args.lora_r,
        "lora_alpha": args.lora_alpha,
        "lora_dropout": args.lora_dropout,
        "bias": "none",
        "task_type": "CAUSAL_LM",
        #"target_modules": "all-linear",
        "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        "modules_to_save": None,
    }
    
    train_conf = SFTConfig(
        bf16=True,
        do_eval=False,
        learning_rate=args.learning_rate,
        log_level="info",
        logging_steps=args.logging_steps,
        logging_strategy="steps",
        lr_scheduler_type=args.lr_scheduler_type,
        num_train_epochs=args.epochs,
        max_steps=-1,
        output_dir=args.output_dir,
        overwrite_output_dir=True,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        remove_unused_columns=True,
        save_steps=args.save_steps,
        save_total_limit=1,
        seed=args.seed,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        gradient_accumulation_steps=args.grad_accum_steps,
        warmup_ratio=args.warmup_ratio,
        max_seq_length=args.max_seq_length,
        packing=True,
        report_to="wandb" if use_wandb else "none",
        run_name=args.wandb_run_name if use_wandb else None,
    )
    
    checkpoint_dir = os.path.join(args.output_dir, "checkpoints")
    
    peft_conf = LoraConfig(**peft_config)
    model, tokenizer = load_model(args)

    ###############
    # Setup logging
    ###############
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = train_conf.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process a small summary
    logger.warning(
        f"Process rank: {train_conf.local_rank}, device: {train_conf.device}, n_gpu: {train_conf.n_gpu}"
        + f" distributed training: {bool(train_conf.local_rank != -1)}, 16-bit training: {train_conf.fp16}"
    )
    logger.info(f"Training/evaluation parameters {train_conf}")
    logger.info(f"PEFT parameters {peft_conf}")    
    
    ##################
    # Data Processing
    ##################
    train_dataset = load_dataset('json', data_files=os.path.join(args.train_dir, 'train.jsonl'), split='train')
    eval_dataset = load_dataset('json', data_files=os.path.join(args.train_dir, 'eval.jsonl'), split='train')
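    # Note: load_dataset('json', data_files=...) exposes a single file under the 'train' split,
    # which is why eval.jsonl is also read with split='train'.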
    column_names = list(train_dataset.features)

    processed_train_dataset = train_dataset.map(
        apply_chat_template,
        fn_kwargs={"tokenizer": tokenizer},
        num_proc=10,
        remove_columns=column_names,
        desc="Applying chat template to train dataset",
    )

    processed_eval_dataset = eval_dataset.map(
        apply_chat_template,
        fn_kwargs={"tokenizer": tokenizer},
        num_proc=10,
        remove_columns=column_names,
        desc="Applying chat template to eval dataset",
    )

    ###########
    # Training
    ###########
    trainer = SFTTrainer(
        model=model,
        args=train_conf,
        peft_config=peft_conf,
        train_dataset=processed_train_dataset,
        eval_dataset=processed_eval_dataset,
        tokenizer=tokenizer
    )
    
    # Show current memory stats
    gpu_stats = torch.cuda.get_device_properties(0)
    start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
    max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
    logger.info(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
    logger.info(f"{start_gpu_memory} GB of memory reserved.")

    # Resume from the most recently modified checkpoint under output_dir/checkpoints, if one exists
    last_checkpoint = None
    if os.path.isdir(checkpoint_dir):
        checkpoints = [os.path.join(checkpoint_dir, d) for d in os.listdir(checkpoint_dir)]
        if len(checkpoints) > 0:
            checkpoints.sort(key=os.path.getmtime, reverse=True)
            last_checkpoint = checkpoints[0]        

    trainer_stats = trainer.train(resume_from_checkpoint=last_checkpoint)
    
    metrics = trainer_stats.metrics
 
    # Show final memory and time stats 
    used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
    used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
    used_percentage = round(used_memory / max_memory * 100, 3)
    lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)

    logger.info(f"{metrics['train_runtime']} seconds used for training.")
    logger.info(f"{round(metrics['train_runtime']/60, 2)} minutes used for training.")
    logger.info(f"Peak reserved memory = {used_memory} GB.")
    logger.info(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
    logger.info(f"Peak reserved memory % of max memory = {used_percentage} %.")
    logger.info(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")
    
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
    
    #############
    # Evaluation
    #############
    # Decoder-only models are conventionally left-padded for evaluation/generation
    tokenizer.padding_side = "left"
    metrics = trainer.evaluate()
    metrics["eval_samples"] = len(processed_eval_dataset)
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)
    
    ############
    # Save model
    ############
    os.makedirs(args.model_dir, exist_ok=True)

    if args.save_merged_model:
        model_tmp_dir = "model_tmp"
        os.makedirs(model_tmp_dir, exist_ok=True)
        trainer.model.save_pretrained(model_tmp_dir)
        print(f"Save merged model: {args.model_dir}")
        from peft import AutoPeftModelForCausalLM
        model = AutoPeftModelForCausalLM.from_pretrained(model_tmp_dir, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16)
        merged_model = model.merge_and_unload()
        merged_model.save_pretrained(args.model_dir, safe_serialization=True)
    else:
        print(f"Save PEFT model: {args.model_dir}")    
        trainer.model.save_pretrained(args.model_dir)

    tokenizer.save_pretrained(args.model_dir)
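
A minimal sketch of how this entry point might be invoked. The argument parser is not part of this excerpt, so the flag types and default values below are illustrative assumptions; only the attribute names are taken from how args is used in main() above.

# Hypothetical argument parser for main(); names mirror the attributes accessed
# above, while every type and default value here is an assumption.
import argparse

def parse_args():
    p = argparse.ArgumentParser()
    # W&B settings (empty strings leave W&B reporting disabled)
    p.add_argument("--wandb_api_key", type=str, default="")
    p.add_argument("--wandb_project", type=str, default="")
    p.add_argument("--wandb_watch", type=str, default="")
    p.add_argument("--wandb_log_model", type=str, default="")
    p.add_argument("--wandb_run_name", type=str, default="")
    # LoRA hyper-parameters
    p.add_argument("--lora_r", type=int, default=16)
    p.add_argument("--lora_alpha", type=int, default=32)
    p.add_argument("--lora_dropout", type=float, default=0.05)
    # Training hyper-parameters
    p.add_argument("--learning_rate", type=float, default=2e-4)
    p.add_argument("--lr_scheduler_type", type=str, default="cosine")
    p.add_argument("--epochs", type=int, default=1)
    p.add_argument("--train_batch_size", type=int, default=1)
    p.add_argument("--eval_batch_size", type=int, default=1)
    p.add_argument("--grad_accum_steps", type=int, default=4)
    p.add_argument("--warmup_ratio", type=float, default=0.03)
    p.add_argument("--max_seq_length", type=int, default=2048)
    p.add_argument("--logging_steps", type=int, default=10)
    p.add_argument("--save_steps", type=int, default=100)
    p.add_argument("--seed", type=int, default=42)
    # Data and output locations
    p.add_argument("--train_dir", type=str, required=True)
    p.add_argument("--output_dir", type=str, required=True)
    p.add_argument("--model_dir", type=str, required=True)
    p.add_argument("--save_merged_model", action="store_true")
    return p.parse_args()

if __name__ == "__main__":
    main(parse_args())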