from collections import OrderedDict
from tempfile import TemporaryDirectory
from typing import Any, Callable, Dict, List

import torch
from accelerate import Accelerator

# from accelerate.utils import compile_regions
from datasets import Dataset
from transformers import Trainer, TrainerCallback, TrainerState, TrainingArguments
from transformers.quantizers import AutoQuantizationConfig

from ...import_utils import (
    is_deepspeed_available,
    is_gptqmodel_available,
    is_torch_distributed_available,
    is_zentorch_available,
)
from ..base import Backend
from ..peft_utils import apply_peft
from ..transformers_utils import fast_weights_init
from .config import PyTorchConfig

if is_deepspeed_available():
    import deepspeed  # type: ignore

if is_torch_distributed_available():
    import torch.distributed  # type: ignore

if is_zentorch_available():
    import zentorch  # type: ignore # noqa: F401

if is_gptqmodel_available():
    import enum

    # Workaround for gptqmodel, which accesses ``enum.EnumType``. That name only exists
    # on Python >= 3.11; older interpreters expose the same metaclass as ``enum.EnumMeta``,
    # so alias it when needed.
    if hasattr(enum, "EnumMeta") and not hasattr(enum, "EnumType"):
        setattr(enum, "EnumType", enum.EnumMeta)


class PyTorchBackend(Backend[PyTorchConfig]):
    """Benchmark backend that loads and runs models with plain PyTorch.

    The model library is selected by ``config.library`` (``"transformers"``,
    ``"diffusers"`` or ``"timm"``); the loaded model is stored in
    ``self.pretrained_model``. Inference goes through ``forward``/``prefill``/
    ``generate``/``call`` and training through ``transformers.Trainer``.
    """

    NAME = "pytorch"

    def __init__(self, config: PyTorchConfig):
        """Apply process-wide PyTorch settings (threads, TF32, autocast) from ``config``.

        Raises:
            ValueError: if an autocast dtype is requested for a device other than cpu/cuda.
        """
        super().__init__(config)

        # Threads
        # torch.set_num_interop_threads controls *inter*-op parallelism while
        # torch.set_num_threads controls *intra*-op parallelism.
        if self.config.inter_op_num_threads is not None:
            self.logger.info(f"\t+ Setting pytorch inter_op_num_threads({self.config.inter_op_num_threads})")
            torch.set_num_interop_threads(self.config.inter_op_num_threads)

        if self.config.intra_op_num_threads is not None:
            self.logger.info(f"\t+ Setting pytorch intra_op_num_threads({self.config.intra_op_num_threads})")
            torch.set_num_threads(self.config.intra_op_num_threads)

        # TF32
        if self.config.allow_tf32:
            self.logger.info("\t+ Enabling TF32")
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

        # Autocast
        if self.config.autocast_enabled:
            self.logger.info("\t+ Enabling automatic mixed precision")
            torch.set_autocast_enabled(True)

            if self.config.autocast_dtype is not None:
                if self.config.device == "cpu":
                    self.logger.info(f"\t+ Setting autocast cpu dtype to {self.config.autocast_dtype}")
                    torch.set_autocast_cpu_dtype(getattr(torch, self.config.autocast_dtype))
                elif self.config.device == "cuda":
                    self.logger.info(f"\t+ Setting autocast gpu dtype to {self.config.autocast_dtype}")
                    torch.set_autocast_gpu_dtype(getattr(torch, self.config.autocast_dtype))
                else:
                    raise ValueError(f"Device {self.config.device} not supported for autocast")

    def load(self) -> None:
        """Load the model for ``config.library`` into ``self.pretrained_model``.

        A temporary directory is exposed as ``self.tmpdir`` while loading and is
        always cleaned up afterwards, even when loading fails.

        Raises:
            ValueError: if ``config.library`` is not transformers, diffusers or timm.
        """
        self.logger.info("\t+ Creating backend temporary directory")
        self.tmpdir = TemporaryDirectory()

        try:
            if self.config.library == "transformers":
                self.load_transformers_model()
            elif self.config.library == "diffusers":
                self.load_diffusers_model()
            elif self.config.library == "timm":
                self.load_timm_model()
            else:
                raise ValueError(f"Library {self.config.library} not supported for PyTorch backend")
        finally:
            self.logger.info("\t+ Cleaning up backend temporary directory")
            self.tmpdir.cleanup()

    def load_transformers_model_from_pretrained(self) -> None:
        """Load ``config.model`` with the automodel loader and move it to ``config.device``.

        The explicit move is skipped when a ``device_map`` handles placement or the
        target device is the CPU.
        """
        self.logger.info("\t+ Loading Transformers model")
        self.pretrained_model = self.automodel_loader.from_pretrained(
            pretrained_model_name_or_path=self.config.model,
            **self.config.model_kwargs,
            **self.automodel_kwargs,
        )
        if self.config.device_map is None and self.config.device != "cpu":
            self.logger.info(f"\t+ Moving Transformers model to device: {self.config.device}")
            self.pretrained_model = self.pretrained_model.to(self.config.device)

    def load_transformers_model_with_no_weights(self) -> None:
        """Load the model from ``self.no_weights_model_path`` under ``fast_weights_init``.

        ``config.model`` is temporarily pointed at the no-weights path and is
        restored afterwards, even if loading raises.
        """
        with fast_weights_init():
            original_model, self.config.model = self.config.model, self.no_weights_model_path.as_posix()
            try:
                self.load_transformers_model_from_pretrained()
            finally:
                self.config.model = original_model

    def load_transformers_model(self):
        """Load a Transformers model and apply the configured post-load transformations.

        Steps, in order: quantization config, loading (random or pretrained weights),
        KV-cache implementation, BetterTransformer, eval mode, PEFT, DeepSpeed
        inference and torch.compile.

        Raises:
            ValueError: if DeepSpeed-Inference is combined with quantization, or the
                torch.compile target is unsupported.
        """
        if self.config.deepspeed_inference and self.is_quantized:
            raise ValueError("Deepspeed-Inference is not compatible with Transformers quantization")

        # Quantization: user-provided keys override the ones stored in the pretrained config.
        if self.is_quantized:
            self.logger.info("\t+ Processing AutoQuantization config")
            self.quantization_config = AutoQuantizationConfig.from_dict(
                {**self._pretrained_quantization_config, **self.config.quantization_config}
            )

        # Model loading
        if self.config.no_weights:
            self.logger.info("\t+ Creating no weights model")
            if self.config.tp_plan is not None:
                self.create_no_weights_model_slow()
            else:
                self.create_no_weights_model_fast()
            self.logger.info("\t+ Loading model with random weights")
            self.load_transformers_model_with_no_weights()
        else:
            self.logger.info("\t+ Loading model with pretrained weights")
            self.load_transformers_model_from_pretrained()

        # KV-Cache
        if self.config.cache_implementation is not None:
            self.logger.info(f"\t+ Setting cache implementation to {self.config.cache_implementation}")
            self.pretrained_model.generation_config.cache_implementation = self.config.cache_implementation

        # BetterTransformer
        if self.config.to_bettertransformer:
            self.logger.info("\t+ To BetterTransformer")
            self.pretrained_model.to_bettertransformer()

        # Eval mode
        if self.config.eval_mode:
            self.logger.info("\t+ Enabling eval mode")
            self.pretrained_model.eval()

        # PEFT
        if self.config.peft_type is not None:
            self.logger.info("\t+ Applying PEFT")
            self.pretrained_model = apply_peft(self.pretrained_model, self.config.peft_type, self.config.peft_config)

        # DeepSpeed
        if self.config.deepspeed_inference:
            self.logger.info("\t+ Initializing DeepSpeed Inference Engine")
            self.pretrained_model = deepspeed.init_inference(
                model=self.pretrained_model, config=self.config.deepspeed_inference_config
            )

        # Torch compile
        if self.config.torch_compile:
            self._apply_torch_compile()

    def _apply_torch_compile(self) -> None:
        """Wrap the model (target "model") or its ``forward`` (target "forward") with torch.compile.

        Raises:
            ValueError: for any other ``config.torch_compile_target``.
        """
        if self.config.torch_compile_target == "model":
            self.logger.info("\t+ Using torch.compile on model")
            self.pretrained_model = torch.compile(self.pretrained_model, **self.config.torch_compile_config)
        elif self.config.torch_compile_target == "forward":
            self.logger.info("\t+ Using torch.compile on forward")
            self.pretrained_model.forward = torch.compile(
                self.pretrained_model.forward, **self.config.torch_compile_config
            )
        else:
            # NOTE: a "regions" target (accelerate.utils.compile_regions) is currently disabled.
            raise ValueError(f"Target {self.config.torch_compile_target} not supported")

    def load_diffusers_pipeline_from_pretrained(self) -> None:
        """Load a Diffusers pipeline and move it to ``config.device`` when no device_map is used."""
        kwargs = self.automodel_kwargs
        # automodel_kwargs only contains "device_map" when it is set; assign it here so
        # from_pretrained receives the keyword exactly once (passing it both explicitly
        # and via **kwargs raises TypeError).
        kwargs["device_map"] = self.config.device_map
        self.pretrained_model = self.automodel_loader.from_pretrained(
            self.config.model,
            **self.config.model_kwargs,
            **kwargs,
        )
        if self.config.device_map is None and self.config.device != "cpu":
            self.logger.info(f"\t+ Moving Diffusion Pipeline to device: {self.config.device}")
            self.pretrained_model = self.pretrained_model.to(self.config.device)

    def load_diffusers_model(self):
        """Load a Diffusers pipeline, optionally compiling its ``unet`` and ``vae.decode``.

        Raises:
            ValueError: if ``config.no_weights`` is requested (unsupported for Diffusers).
        """
        self.logger.info("\t+ Loading Diffusion Pipeline")
        self.logger.info(f"\t+ Using Diffusers Pipeline {self.automodel_loader.__name__}")

        # Model loading
        if self.config.no_weights:
            raise ValueError("No weights model not supported for Diffusers")
        self.load_diffusers_pipeline_from_pretrained()

        # Torch compile
        if self.config.torch_compile:
            self.logger.info("\t+ Using torch.compile on unet and vae")
            self.pretrained_model.unet = torch.compile(self.pretrained_model.unet, **self.config.torch_compile_config)
            self.pretrained_model.vae.decode = torch.compile(
                self.pretrained_model.vae.decode, **self.config.torch_compile_config
            )

    def load_timm_model_form_pretrained(self) -> None:
        """Create the timm model named ``config.model`` and move it to ``config.device``."""
        self.pretrained_model = self.automodel_loader(model_name=self.config.model)
        if self.config.device != "cpu":
            self.logger.info(f"\t+ Moving Timm model to device: {self.config.device}")
            self.pretrained_model = self.pretrained_model.to(self.config.device)

    def load_timm_model(self):
        """Load a timm model, optionally applying torch.compile.

        Raises:
            ValueError: if ``config.no_weights`` is requested or the compile target is unsupported.
        """
        self.logger.info("\t+ Loading Timm model")
        self.logger.info(f"\t+ Using Timm's {self.automodel_loader.__name__}")

        # Model loading
        if self.config.no_weights:
            raise ValueError("No weights model not supported for Timm")
        self.load_timm_model_form_pretrained()

        # Torch compile
        if self.config.torch_compile:
            self._apply_torch_compile()

    @property
    def _pretrained_quantization_config(self) -> Dict[str, Any]:
        """Quantization config stored on ``self.pretrained_config`` as a dict ({} when absent/None)."""
        quantization_config = getattr(self.pretrained_config, "quantization_config", None)
        if quantization_config is None:
            return {}
        if hasattr(quantization_config, "to_dict"):
            # transformers QuantizationConfigMixin objects expose their fields via to_dict().
            return quantization_config.to_dict()
        return quantization_config

    @property
    def is_quantized(self) -> bool:
        """True if the user requested a quantization scheme or the checkpoint declares a quant_method."""
        return (
            self.config.quantization_scheme is not None
            or self._pretrained_quantization_config.get("quant_method") is not None
        )

    @property
    def is_gptq_quantized(self) -> bool:
        """True if GPTQ quantization is requested by the user or declared by the checkpoint."""
        return (
            self.config.quantization_scheme == "gptq"
            or self._pretrained_quantization_config.get("quant_method") == "gptq"
        )

    @property
    def is_bnb_quantized(self) -> bool:
        """True if bitsandbytes quantization is requested by the user or declared by the checkpoint.

        Checkpoints saved by transformers record ``quant_method="bitsandbytes"``; ``"bnb"``
        is also accepted.
        """
        return self.config.quantization_scheme == "bnb" or self._pretrained_quantization_config.get(
            "quant_method"
        ) in ("bnb", "bitsandbytes")

    @property
    def is_exllamav2(self) -> bool:
        """True if the model is GPTQ-quantized with ``exllama_config["version"] == 2``.

        The version may come from the checkpoint's quantization config or from the
        user-provided ``config.quantization_config``.
        """
        if not self.is_gptq_quantized:
            return False
        # exllama_config may be missing or serialized as None; treat both as "no version".
        pretrained_exllama = self._pretrained_quantization_config.get("exllama_config") or {}
        user_exllama = self.config.quantization_config.get("exllama_config") or {}
        return pretrained_exllama.get("version") == 2 or user_exllama.get("version") == 2

    @property
    def automodel_kwargs(self) -> Dict[str, Any]:
        """Keyword arguments forwarded to the automodel loader's ``from_pretrained``.

        Only options that are set in the config are included.
        """
        kwargs = {}

        if self.config.torch_dtype is not None:
            # Resolve names like "float16" to torch dtypes; pass anything else (e.g. "auto") as-is.
            if hasattr(torch, self.config.torch_dtype):
                kwargs["torch_dtype"] = getattr(torch, self.config.torch_dtype)
            else:
                kwargs["torch_dtype"] = self.config.torch_dtype

        if self.is_quantized:
            kwargs["quantization_config"] = self.quantization_config

        if self.config.attn_implementation is not None:
            kwargs["attn_implementation"] = self.config.attn_implementation

        if self.config.low_cpu_mem_usage is not None:
            kwargs["low_cpu_mem_usage"] = self.config.low_cpu_mem_usage

        if self.config.device_map is not None:
            kwargs["device_map"] = self.config.device_map

        if self.config.tp_plan is not None:
            kwargs["tp_plan"] = self.config.tp_plan

        return kwargs

    @property
    def split_between_processes(self) -> bool:
        """True when inputs should be sharded across an initialized torch.distributed group.

        DeepSpeed-Inference is excluded.
        """
        return (
            is_torch_distributed_available()
            and torch.distributed.is_initialized()
            and not self.config.deepspeed_inference
        )

    def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Shard inputs across processes if needed and move tensors to ``config.device``.

        Returns a new dict; the caller's dict is not modified. For timm models, only
        ``pixel_values`` is kept, renamed to ``x``.
        """
        if self.split_between_processes:
            with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
                inputs = process_inputs

        inputs = {
            key: value.to(self.config.device) if isinstance(value, torch.Tensor) else value
            for key, value in inputs.items()
        }

        if self.config.library == "timm":
            inputs = {"x": inputs["pixel_values"]}

        return inputs

    @torch.inference_mode()
    def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
        """Run ``pretrained_model.forward`` under inference mode."""
        return self.pretrained_model.forward(**inputs, **kwargs)

    @torch.inference_mode()
    def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
        """Run a single-token ``generate`` call (prefill only) under inference mode."""
        assert kwargs.get("max_new_tokens") == kwargs.get("min_new_tokens") == 1, (
            "For prefilling, max_new_tokens and min_new_tokens must be equal to 1"
        )
        return self.pretrained_model.generate(**inputs, **kwargs)

    @torch.inference_mode()
    def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
        """Run ``pretrained_model.generate`` under inference mode."""
        return self.pretrained_model.generate(**inputs, **kwargs)

    @torch.inference_mode()
    def call(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
        """Call the model (``pretrained_model(...)``) under inference mode."""
        return self.pretrained_model(**inputs, **kwargs)

    def train(
        self,
        training_dataset: Dataset,
        training_arguments: Dict[str, Any],
        training_callbacks: List[TrainerCallback],
        training_data_collator: Callable[[List[Dict[str, Any]]], Dict[str, Any]],
    ) -> TrainerState:
        """Train ``pretrained_model`` with ``transformers.Trainer`` and return the final trainer state.

        ``use_cpu`` is forced from ``config.device``. The caller's ``training_arguments``
        dict is not modified.
        """
        self.logger.info(f"\t+ Wrapping training arguments with {TrainingArguments.__name__}")
        training_arguments = TrainingArguments(**{**training_arguments, "use_cpu": self.config.device == "cpu"})
        self.logger.info(f"\t+ Wrapping model with {Trainer.__name__}")
        trainer = Trainer(
            args=training_arguments,
            model=self.pretrained_model,
            callbacks=training_callbacks,
            train_dataset=training_dataset,
            data_collator=training_data_collator,
        )
        self.logger.info("\t+ Starting training")
        trainer.train()
        self.logger.info("\t+ Finished training")
        return trainer.state
