import os
import sys
import logging
import argparse
from datetime import datetime

import torch
import datasets
import transformers
import mlflow
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

logger = logging.getLogger(__name__)

def log_params_from_dict(config, mlflow_client, parent_key=''):
    """
    Given a dictionary of parameters, logs non-dictionary values to the specified mlflow client (or module).
    Nested dictionaries are ignored; lists are logged as comma-separated strings.

    Args:
        config (dict): The dictionary of parameters to log.
        mlflow_client: The mlflow client to use for logging.
        parent_key (str): Used to prefix keys (for nested logging).
    """
    for key, value in config.items():
        if isinstance(value, dict):
            continue
        full_key = f"{parent_key}.{key}" if parent_key else key
        if isinstance(value, list):
            mlflow_client.log_param(full_key, ','.join(map(str, value)))
        else:
            mlflow_client.log_param(full_key, value)
            

def load_model(args):
    model_name_or_path = args.model_name_or_path
    model_kwargs = dict(
        use_cache=False,
        trust_remote_code=True,
        # attn_implementation="flash_attention_2",  # load the model with flash-attention support
        torch_dtype=torch.bfloat16,
        device_map=None
    )
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **model_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    tokenizer.model_max_length = args.max_seq_length
    tokenizer.pad_token = tokenizer.unk_token  # use unk rather than eos token to prevent endless generation
    tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
    tokenizer.padding_side = "right"
    return model, tokenizer

def apply_chat_template(
    example,
    tokenizer,
):
    messages = example["messages"]
    # Add an empty system message if there is none
    if messages[0]["role"] != "system":
        messages.insert(0, {"role": "system", "content": ""})
    example["text"] = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=False)
    return example
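
# Assumed record format for train.jsonl / eval.jsonl (one JSON object per line), matching the
# "messages" field consumed by apply_chat_template above, e.g.:
# {"messages": [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}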

def main(args):

    ###################
    # Hyper-parameters
    ###################
    # Only overwrite environ if the corresponding wandb param was passed
    if len(args.wandb_api_key) > 0:
        os.environ['WANDB_API_KEY'] = args.wandb_api_key
    if len(args.wandb_project) > 0:
        os.environ["WANDB_PROJECT"] = args.wandb_project
    if len(args.wandb_watch) > 0:
        os.environ["WANDB_WATCH"] = args.wandb_watch
    if len(args.wandb_log_model) > 0:
        os.environ["WANDB_LOG_MODEL"] = args.wandb_log_model

    use_wandb = len(args.wandb_project) > 0 or ("WANDB_PROJECT" in os.environ and len(os.environ["WANDB_PROJECT"]) > 0) 

    train_conf = SFTConfig(
        bf16=True,
        do_eval=False,
        learning_rate=args.learning_rate,
        log_level="info",
        logging_steps=args.logging_steps,
        logging_strategy="steps",
        lr_scheduler_type=args.lr_scheduler_type,
        num_train_epochs=args.epochs,
        max_steps=-1,
        output_dir=args.output_dir,
        overwrite_output_dir=True,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        remove_unused_columns=True,
        save_steps=args.save_steps,
        save_total_limit=1,
        seed=args.seed,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        gradient_accumulation_steps=args.grad_accum_steps,
        warmup_ratio=args.warmup_ratio,
        max_seq_length=args.max_seq_length,
        packing=True,
        report_to="wandb" if use_wandb else "none",
        run_name=args.wandb_run_name if use_wandb else None    
    )    
    
    peft_config = {
        "r": args.lora_r,
        "lora_alpha": args.lora_alpha,
        "lora_dropout": args.lora_dropout,
        "bias": "none",
        "task_type": "CAUSAL_LM",
        #"target_modules": "all-linear",
        "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        "modules_to_save": None,
    }
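    # Note: the target_modules above follow the Llama/Mistral naming convention. Phi-3 / Phi-3.5
    # checkpoints fuse these projections into qkv_proj and gate_up_proj, so with the default model
    # only o_proj and down_proj would match and receive LoRA adapters; the commented-out
    # "all-linear" option is a model-agnostic alternative.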

    peft_conf = LoraConfig(**peft_config)
    model, tokenizer = load_model(args)

    ###############
    # Setup logging
    ###############
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = train_conf.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process a small summary
    logger.warning(
        f"Process rank: {train_conf.local_rank}, device: {train_conf.device}, n_gpu: {train_conf.n_gpu},"
        + f" distributed training: {bool(train_conf.local_rank != -1)}, 16-bit training: {train_conf.fp16}"
    )
    logger.info(f"Training/evaluation parameters {train_conf}")
    logger.info(f"PEFT parameters {peft_conf}")    

    ##################
    # Data Processing
    ##################
    train_dataset = load_dataset('json', data_files=os.path.join(args.train_dir, 'train.jsonl'), split='train')
    eval_dataset = load_dataset('json', data_files=os.path.join(args.train_dir, 'eval.jsonl'), split='train')
    column_names = list(train_dataset.features)

    processed_train_dataset = train_dataset.map(
        apply_chat_template,
        fn_kwargs={"tokenizer": tokenizer},
        num_proc=10,
        remove_columns=column_names,
        desc="Applying chat template to train_sft",
    )

    processed_eval_dataset = eval_dataset.map(
        apply_chat_template,
        fn_kwargs={"tokenizer": tokenizer},
        num_proc=10,
        remove_columns=column_names,
        desc="Applying chat template to test_sft",
    )
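
    # Both maps leave a single "text" column holding the rendered chat template; SFTTrainer reads
    # it via SFTConfig.dataset_text_field, which defaults to "text" in recent trl releases (older
    # versions may require setting it explicitly).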
    
    with mlflow.start_run() as run:     
        
        log_params_from_dict(train_conf.to_dict(), mlflow)
        log_params_from_dict(peft_config, mlflow)
        
        ###########
        # Training
        ###########
        trainer = SFTTrainer(
            model=model,
            args=train_conf,
            peft_config=peft_conf,
            train_dataset=processed_train_dataset,
            eval_dataset=processed_eval_dataset,
            tokenizer=tokenizer
        )

        # Show current memory stats
        gpu_stats = torch.cuda.get_device_properties(0)
        start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
        max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
        logger.info(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
        logger.info(f"{start_gpu_memory} GB of memory reserved.")

        trainer_stats = trainer.train()

        #############
        # Logging
        #############
        metrics = trainer_stats.metrics

        # Show final memory and time stats 
        used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
        used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
        used_percentage = round(used_memory / max_memory * 100, 3)
        lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)

        logger.info(f"{metrics['train_runtime']} seconds used for training.")
        logger.info(f"{round(metrics['train_runtime']/60, 2)} minutes used for training.")
        logger.info(f"Peak reserved memory = {used_memory} GB.")
        logger.info(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
        logger.info(f"Peak reserved memory % of max memory = {used_percentage} %.")
        logger.info(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")
                
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
        
        #############
        # Evaluation
        #############
        # tokenizer.padding_side = "left"
        # metrics = trainer.evaluate()
        # metrics["eval_samples"] = len(processed_eval_dataset)
        # trainer.log_metrics("eval", metrics)
        # trainer.save_metrics("eval", metrics)

        # ############
        # # Save model
        # ############
        os.makedirs(args.model_dir, exist_ok=True)

        if args.save_merged_model:
            model_tmp_dir = "model_tmp"
            os.makedirs(model_tmp_dir, exist_ok=True)
            trainer.model.save_pretrained(model_tmp_dir)
            print(f"Save merged model: {args.model_dir}")
            from peft import AutoPeftModelForCausalLM
            model = AutoPeftModelForCausalLM.from_pretrained(model_tmp_dir, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16)
            merged_model = model.merge_and_unload()
            merged_model.save_pretrained(args.model_dir, safe_serialization=True)
        else:
            print(f"Save PEFT model: {args.model_dir}")    
            trainer.model.save_pretrained(args.model_dir)

        tokenizer.save_pretrained(args.model_dir)             
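
        # Sketch of how the saved artifacts could be reloaded later (illustrative, not executed here):
        #   merged model:  AutoModelForCausalLM.from_pretrained(args.model_dir)
        #   PEFT adapter:  base = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
        #                  model = PeftModel.from_pretrained(base, args.model_dir)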


def parse_args():
    # setup argparse
    parser = argparse.ArgumentParser()
    # curr_time = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")

    # hyperparameters
    parser.add_argument("--model_name_or_path", default="microsoft/Phi-3.5-mini-instruct", type=str, help="Input directory for training")    
    parser.add_argument("--train_dir", default="data", type=str, help="Input directory for training")
    parser.add_argument("--model_dir", default="./model", type=str, help="output directory for model")
    parser.add_argument("--epochs", default=1, type=int, help="number of epochs")
    parser.add_argument("--train_batch_size", default=8, type=int, help="training - mini batch size for each gpu/process")
    parser.add_argument("--eval_batch_size", default=8, type=int, help="evaluation - mini batch size for each gpu/process")
    parser.add_argument("--learning_rate", default=5e-06, type=float, help="learning rate")
    parser.add_argument("--logging_steps", default=2, type=int, help="logging steps")
    parser.add_argument("--save_steps", default=100, type=int, help="save steps")    
    parser.add_argument("--grad_accum_steps", default=4, type=int, help="gradient accumulation steps")
    parser.add_argument("--lr_scheduler_type", default="linear", type=str)
    parser.add_argument("--seed", default=0, type=int, help="seed")
    parser.add_argument("--warmup_ratio", default=0.2, type=float, help="warmup ratio")
    parser.add_argument("--max_seq_length", default=2048, type=int, help="max seq length")
    parser.add_argument("--save_merged_model", type=bool, default=False)

    # lora hyperparameters
    parser.add_argument("--lora_r", default=16, type=int, help="lora r")
    parser.add_argument("--lora_alpha", default=16, type=int, help="lora alpha")
    parser.add_argument("--lora_dropout", default=0.05, type=float, help="lora dropout")
    
    # wandb params
    parser.add_argument("--wandb_api_key", type=str, default="")
    parser.add_argument("--wandb_project", type=str, default="")
    parser.add_argument("--wandb_run_name", type=str, default="")
    parser.add_argument("--wandb_watch", type=str, default="gradients") # options: false | gradients | all
    parser.add_argument("--wandb_log_model", type=str, default="false") # options: false | true

    # parse args
    args = parser.parse_args()

    # return args
    return args

if __name__ == "__main__":
    #sys.argv = ['']
    args = parse_args()
    main(args)
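
# Example invocation (hypothetical script name, paths, and values):
#   python train.py --train_dir ./data --model_dir ./model --epochs 1 \
#       --train_batch_size 4 --grad_accum_steps 8 --lora_r 16 --save_merged_model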
