import os
import math
import logging
import pathlib
from dataclasses import dataclass, field
from typing import Optional, List, Tuple
import wandb

import torch
import transformers
from transformers import (
    HfArgumentParser,
    AutoConfig,
    AutoProcessor,
    TrainingArguments,
    set_seed
)


import torch.distributed as dist
# LoRA / PEFT if needed
try:
    from peft import (
        LoraConfig,
        get_peft_model,
        prepare_model_for_kbit_training,
    )
except ImportError:
    LoraConfig = None
    get_peft_model = None
    prepare_model_for_kbit_training = None

# BitsAndBytes if needed
try:
    from transformers import BitsAndBytesConfig
except ImportError:
    BitsAndBytesConfig = None

from smolvlm.train.smolvlm_trainer import SmolVLMTrainer
from smolvlm.train.args import DataArguments, ModelArguments, TrainingArguments
from smolvlm.datasets.builder import make_supervised_data_module

logger = logging.getLogger(__name__)

#TODO: check what these do. 
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'


def trainer_save_model_safe(trainer: SmolVLMTrainer):
    """
    Safely saves the model if not in DeepSpeed ZeRO stage-3.
    """
    if trainer.is_deepspeed_enabled:
        trainer.save_model()
    else:
        state_dict = trainer.model.state_dict()
        cpu_state_dict = {k: v.cpu() for k, v in state_dict.items()}
        del state_dict
        trainer._save(trainer.args.output_dir, state_dict=cpu_state_dict)


def get_nb_trainable_parameters(model: torch.nn.Module) -> Tuple[int, int]:
    """
    Returns (trainable_params, total_params) across the entire model.
    """
    trainable_params = 0
    total_params = 0
    for _, param in model.named_parameters():
        total_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    return trainable_params, total_params


def set_trainable_params(model: torch.nn.Module, training_args: TrainingArguments):
    """
    Freezes all parameters, then selectively unfreezes based on user flags:
     - tune_vision_tower => unfreeze vision tower
     - tune_mm_connector => unfreeze connector
     - tune_language_model => unfreeze base language model
    Prints out which modules are unfrozen/frozen for clarity.
    """

    
    for param_name, param in model.named_parameters():
        param.requires_grad = False

    vis_unfrozen, conn_unfrozen, llm_unfrozen = 0, 0, 0
    vis_total, conn_total, llm_total = 0, 0, 0

    for name, param in model.named_parameters():
        if ("vision_model" in name) or ("vision_tower" in name):
            vis_total += param.numel()
            if training_args.tune_vision_tower:
                param.requires_grad = True
                vis_unfrozen += param.numel()

        elif ("connector" in name) or ("modality_projection" in name) or ("merger" in name):
            conn_total += param.numel()
            if training_args.tune_mm_connector:
                param.requires_grad = True
                conn_unfrozen += param.numel()
        else:
            
            llm_total += param.numel()
            if training_args.tune_language_model:
                param.requires_grad = True
                llm_unfrozen += param.numel()

    # Summaries
    trainable_params, total_params = get_nb_trainable_parameters(model)
    pct = 100.0 * trainable_params / max(total_params, 1)

    logger.info("=== Freeze/Unfreeze Summary ===")
    logger.info(f"  tune_vision_tower={training_args.tune_vision_tower} "
                f"(unfrozen {vis_unfrozen:,d}/{vis_total:,d} params)")
    logger.info(f"  tune_mm_connector={training_args.tune_mm_connector} "
                f"(unfrozen {conn_unfrozen:,d}/{conn_total:,d} params)")
    logger.info(f"  tune_language_model={training_args.tune_language_model} "
                f"(unfrozen {llm_unfrozen:,d}/{llm_total:,d} params)")
    logger.info(f"  => Overall trainable params: {trainable_params:,d} / {total_params:,d} "
                f"({pct:.2f}%)\n")


def enable_gradient_checkpointing(model: torch.nn.Module, training_args: TrainingArguments):
    """
    Enables gradient checkpointing if specified in TrainingArguments.
    If the model's LLM submodule supports enabling input grads, we do so (like Qwen does).
    Otherwise, we attach a forward hook to the input embedding to require grads.
    """

    logger.info("Enabling gradient checkpointing in the model.")
    model.config.use_cache = False
    model.config.use_reentrant_checkpointing = False
    # model.gradient_checkpointing_enable()
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    if hasattr(model, "model") and hasattr(model.model, "llm"):
        # Some models have a method like `enable_input_require_grads`.
        if hasattr(model.model.llm, "enable_input_require_grads"):
            logger.info("Calling model.model.llm.enable_input_require_grads() for better GC.")
            model.model.llm.enable_input_require_grads()
        else:
            # fallback approach
            logger.info("Attaching a forward hook to require grad on embeddings output.")
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)
            input_embed = model.get_input_embeddings() if hasattr(model, "get_input_embeddings") else None
            if input_embed is not None:
                input_embed.register_forward_hook(make_inputs_require_grad)


def prepare_model(
    model_args: ModelArguments,
    training_args: TrainingArguments
):
    """
    Loads and configures the smolVLM model (Idefics3ForConditionalGeneration),
    applying rope scaling if needed, plus optional bitsandbytes quant config.
    """
    logger.info("Loading config from %s", model_args.model_name_or_path)
    config = AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        trust_remote_code=model_args.trust_remote_code,
    )
    compute_dtype = (
        torch.float16 if training_args.fp16
        else (torch.bfloat16 if training_args.bf16 else torch.float32)
    )
    # Possibly adjust rope scaling for a longer context
    orig_ctx_len = getattr(config, "max_position_embeddings", None)
    if orig_ctx_len and training_args.model_max_length > orig_ctx_len:
        factor = math.ceil(training_args.model_max_length / orig_ctx_len)
        config.rope_scaling = {"type": "linear", "factor": factor}
        logger.info(f"Auto rope scaling => from {orig_ctx_len} to {training_args.model_max_length}. Factor={factor}")

    # For training, disable cache to reduce memory usage
    config.use_cache = False

    bnb_args = {}
    # If using bitsandbytes in 4- or 8-bit
    if BitsAndBytesConfig is not None and training_args.bits in [4, 8]:
        bnb_args["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=(training_args.bits == 4),
            load_in_8bit=(training_args.bits == 8),
            llm_int8_skip_modules=["lm_head"],
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=training_args.double_quant,
            bnb_4bit_quant_type=training_args.quant_type,
        )
        logger.info(f"Using bitsandbytes quantization: bits={training_args.bits}, type={training_args.quant_type}")

    # Possibly set attention impl
    # if training_args.disable_flash_attn2:
    #     attn_impl = "sdpa"
    # else: 
    #     attn_impl = "flash_attention_2"
    # logger.info("Instantiating SmolVLMForConditionalGeneration with attention impl=%s", attn_impl)

    if training_args.disable_flash_attn2:
        config._attn_implementation = "sdpa"
    else: 
        config._attn_implementation = "flash_attention_2"

    if model_args.frames_per_clip > 1:
        from smolvlm.model.modeling_smollmm import SmolLMMForConditionalGeneration
        config.frames_per_clip = model_args.frames_per_clip
        model_cls = SmolLMMForConditionalGeneration
        logger.info(f"Using frame emmbedding averaging of {model_args.frames_per_clip} frames")
    else:
        from smolvlm.model.modeling_smolvlm import SmolVLMForConditionalGeneration
        model_cls = SmolVLMForConditionalGeneration
    
    model = model_cls.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=compute_dtype,
        config = config,
        **bnb_args,
    )

    return model


def apply_peft(model: torch.nn.Module, training_args: TrainingArguments) -> torch.nn.Module:
    """
    Applies LoRA/PEFT if training_args.peft_enable is True.
    Also calls `prepare_model_for_kbit_training` if bits=4 or 8.
    """

    if (LoraConfig is None) or (get_peft_model is None):
        raise ValueError("PEFT is not installed, but peft_enable=True was set.")

    logger.info("PEFT/LoRA is enabled. Building LoRA config...")

    # If user hasn't provided specific modules, pick a guess
    peft_target_modules = training_args.target_modules
    if not peft_target_modules:
        peft_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]

    # For 4bit/8bit, we can do some prep:
    if training_args.bits in [4, 8] and prepare_model_for_kbit_training is not None:
        logger.info("Running `prepare_model_for_kbit_training` for LoRA + 4/8-bit support.")
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=training_args.gradient_checkpointing
        )

    lora_config = LoraConfig(
        r=training_args.lora_rank,
        lora_alpha=training_args.lora_alpha,
        lora_dropout=training_args.lora_dropout,
        target_modules=peft_target_modules,
        bias=training_args.lora_bias,  # "none"/"all"/"lora_only"
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    logger.info("LoRA applied. Trainable parameters:")
    model.print_trainable_parameters()
    return model


def auto_resume_or_start(training_args: TrainingArguments) -> bool:
    """
    Detect if there's a previous checkpoint to resume from.
    Return True if we found a checkpoint. Otherwise, we start fresh.
    """
    ckpts = list(pathlib.Path(training_args.output_dir).glob("checkpoint-*"))
    return len(ckpts) > 0


def train():
    """
    Main fine-tuning entry point for your smolVLM model,
    with optional LoRA + bitsandbytes, and prints which submodules are frozen/unfrozen.
    """
    logging.basicConfig(level=logging.INFO)
    logger.info("Parsing arguments with HfArgumentParser...")

    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    
    if data_args.packed and model_args.apply_diagonal_block_attention:
        from smolvlm.model.varlen_packing import apply_varlen_patch
        apply_varlen_patch()
    elif data_args.packed and not model_args.apply_diagonal_block_attention:
        logger.warn("Sequence packing has being enabled WITHOUGH diagonal block attention!")
    elif not data_args.packed and model_args.apply_diagonal_block_attention:
        logger.warn("diagonal block attention has been enabled WITHOUT sequence packing. Ignoring flag!")    
        
    # Ensure reproducibility
    set_seed(training_args.seed)
    
    # Initialize wandb only on the main process (global rank 0) if wandb logging is enabled
    if "wandb" in training_args.report_to and training_args.local_rank == 0 and dist.get_rank() == 0:
        os.environ["WANDB_PROJECT"] = "smolvlmvideo"  # Set project name
        wandb.init(
            name=training_args.run_name,
            config=training_args.to_dict(),
        )
        # Ensure other processes will not try to log
        os.environ["WANDB_MODE"] = "offline"


    # Possibly set tune flags automatically based on user-provided LR
    training_args.tune_language_model = training_args.language_model_lr > 1e-9
    training_args.tune_mm_connector = training_args.connector_lr > 1e-9
    training_args.tune_vision_tower = training_args.vision_tower_lr > 1e-9

    # 1) Prepare model + config
    logger.info("Preparing model + config (possibly with bitsandbytes) ...")
    model = prepare_model(model_args, training_args)

    # 2) Freeze/unfreeze based on user flags, plus prints
    set_trainable_params(model, training_args)

    # 3) Possibly enable gradient checkpointing
    if training_args.gradient_checkpointing:
        enable_gradient_checkpointing(model, training_args)

    # 4) Possibly apply LoRA/PEFT
    if training_args.peft_enable:
        model = apply_peft_if_needed(model, training_args)

    # 5) Load processor (tokenizer + image processor, etc.)
    #import ipdb; ipdb.set_trace()
    logger.info("Loading AutoProcessor from %s", model_args.model_name_or_path)
    if model_args.frames_per_clip > 1:
        from smolvlm.model.processing_smollmm import SmolLMMProcessor
        processor = SmolLMMProcessor.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side=model_args.padding_side,
            trust_remote_code=model_args.trust_remote_code,
        )
    else:
        processor = AutoProcessor.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side=model_args.padding_side,
            trust_remote_code=model_args.trust_remote_code,
        )

    # 6) Build dataset + collator
    logger.info("Building dataset + collator...")
    data_module = make_supervised_data_module(processor, data_args, training_args, model_args)

    # 7) Initialize custom trainer
    logger.info("Initializing SmolVLMTrainer...")
    trainer = SmolVLMTrainer(
        model=model,
        args=training_args,
        **data_module
    )

    # 8) Possibly auto-resume from checkpoint
    resume_training = auto_resume_or_start(training_args)
    if resume_training:
        logger.info("Resuming from a previous checkpoint in %s ...", training_args.output_dir)
        trainer.train(resume_from_checkpoint=True)
    else:
        logger.info("Starting a fresh training run...")
        trainer.train()

    # 9) Post-training final steps
    logger.info("Training completed. Saving final model...")
    # Re-enable model cache if needed
    model.config.use_cache = True

    # Save trainer state
    trainer.save_state()

    # Save final model (special logic if in Deepspeed)
    if trainer.is_deepspeed_enabled:
        trainer.save_model()
    else:
        trainer_save_model_safe(trainer)

    logger.info("All done. Exiting successfully.")


if __name__ == "__main__":
    train()