# Copyright 2023 The HuggingFace Team. All rights reserved.
# Licensed under the MIT License.

import random
from collections.abc import Iterable, MutableMapping, MutableSequence
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
from datasets import load_dataset
from tqdm import tqdm

from optimum.utils.normalized_config import NormalizedConfigManager
from transformers import AutoConfig


if TYPE_CHECKING:
    from .configuration import BrevitasQuantizationConfig

# Lists of model-config attribute names, presumably candidates to try when looking up the hidden size and the
# number of attention heads across architectures.
# NOTE(review): neither constant is referenced in this part of the file — confirm it is still used elsewhere.
HIDDEN_SIZE_KEYS = ["d_model", "hidden_size"]
NUM_HEADS_KEYS = ["num_attention_heads"]


@torch.no_grad()
def recursive_to_device(
    tensor_or_iterable: Union[Iterable, torch.Tensor], device
) -> Union[Iterable, torch.Tensor]:
    """
    Move a tensor, or every tensor nested in tuples, lists and mappings, to `device`.

    Tuples are rebuilt because they are immutable; lists and mappings are updated in place and the same
    object is returned.

    Args:
        tensor_or_iterable (`Union[Iterable, torch.Tensor]`):
            A tensor, or an arbitrarily nested tuple/list/mapping whose leaves are tensors.
        device:
            Target device, as accepted by `torch.Tensor.to`.

    Returns:
        The moved tensor, or the container holding the moved tensors.

    Raises:
        `ValueError`: If a leaf is neither a tensor nor a supported container (e.g. an `int` or a `str`).
    """
    if isinstance(tensor_or_iterable, torch.Tensor):
        return tensor_or_iterable.to(device)
    if isinstance(tensor_or_iterable, tuple):  # Special handling of tuples, since they are immutable
        return tuple(recursive_to_device(item, device) for item in tensor_or_iterable)
    if isinstance(tensor_or_iterable, MutableMapping):
        # Recurse into the values, not the keys: keys such as "input_ids" are strings, not tensors.
        # Re-assigning existing keys does not change the mapping's size, so this is safe during iteration.
        for key, value in tensor_or_iterable.items():
            tensor_or_iterable[key] = recursive_to_device(value, device)
        return tensor_or_iterable
    if isinstance(tensor_or_iterable, MutableSequence):
        # Index by position; the elements themselves are not valid indices.
        for index, value in enumerate(tensor_or_iterable):
            tensor_or_iterable[index] = recursive_to_device(value, device)
        return tensor_or_iterable
    # Strings, numbers and other leaves cannot be moved (a `str` is iterable and would otherwise recurse forever).
    raise ValueError(f"Cannot move {type(tensor_or_iterable)} to {device}")


@torch.no_grad()
def compute_perplexity(model: torch.nn.Module, data: List[Dict], context_length: int, tokenizer: Any, seed: int = 0):
    """
    Compute the perplexity of `model` on `data`.

    Each sample is split into consecutive, non-overlapping windows of `2 * context_length` tokens (the last one
    may be shorter). In every window the first `context_length` tokens only serve as context, and the mean
    cross-entropy is computed on the remaining tokens. Windows without any token past the context are skipped.
    The returned perplexity is the exponential of the mean of the per-window losses.

    Args:
        model (`torch.nn.Module`):
            Model whose output supports `output["logits"]`.
        data (`List[Dict]`):
            Samples holding `input_ids` and `attention_mask` tensors indexed as `(batch, sequence)`, and optionally
            `past_key_values` (only forwarded when `model` is a `torch.fx.GraphModule`). `data` is not modified.
        context_length (`int`):
            Number of leading tokens of each window used as context only.
        tokenizer (`Any`):
            Tokenizer of the model. When `tokenizer.bos_token_id` is not `None`, it replaces the first token of
            every window.
        seed (`int`, defaults to `0`):
            Seed for `random`, `numpy` and `torch`.

    Returns:
        `torch.Tensor`: Scalar perplexity.

    Raises:
        `ValueError`: If no window contains a token past the context (e.g. empty `data`).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.random.manual_seed(seed)

    model = model.eval()

    # In accelerate by default `io_same_device=True`, and here we want the model output on the hook's execution
    # device; otherwise the inputs simply go to the device of the model parameters. This is loop-invariant.
    if hasattr(model, "hf_device_map") and hasattr(model, "_hf_hook"):
        device = model._hf_hook.execution_device
    else:
        device = next(model.parameters()).device

    cross_entropy_loss = nn.CrossEntropyLoss()
    window_length = context_length * 2

    nlls = []
    for sample in tqdm(data, desc="Computing perplexity..."):
        sample_length = sample["input_ids"].shape[1]
        for start_index in range(0, sample_length, window_length):
            end_index = min(start_index + window_length, sample_length)
            # A window needs at least one token after the context, otherwise the loss is NaN.
            if end_index - start_index <= context_length:
                continue

            subsample = {
                # Clone: slicing returns a view, and writing the BOS token below must not modify `data`.
                "input_ids": sample["input_ids"][:, start_index:end_index].clone(),
                "attention_mask": sample["attention_mask"][:, start_index:end_index],
            }

            # In case we are using torch.fx, we can not have optional inputs, and we have traced the model with past_key_values inputs, thus we need them here as well.
            if "past_key_values" in sample and isinstance(model, torch.fx.GraphModule):
                subsample["past_key_values"] = sample["past_key_values"]

            # Add BOS token.
            if tokenizer.bos_token_id is not None:
                subsample["input_ids"][:, 0] = tokenizer.bos_token_id

            subsample = {name: recursive_to_device(val, device) for name, val in subsample.items()}

            lm_logits = model(**subsample)["logits"]

            # The logits at position i predict token i + 1, so only tokens after the context are scored.
            reference_labels = subsample["input_ids"][:, context_length:]
            shift_logits = lm_logits[:, context_length - 1 : -1]

            # Fuse batch and sequence length dimensions; `reshape` also handles batch > 1 and non-contiguous slices.
            reference_labels = reference_labels.reshape(-1)
            shift_logits = shift_logits.reshape(-1, shift_logits.shape[-1])

            nlls.append(cross_entropy_loss(shift_logits, reference_labels))

    if not nlls:
        raise ValueError(
            f"No window of `data` contains more than context_length={context_length} tokens; cannot compute perplexity."
        )

    ppl = torch.exp(torch.stack(nlls).mean())

    return ppl


def get_wikitext2(
    tokenizer: Any, seqlen: int, nsamples: int, split: str = "train", fuse_sequences: bool = True, seed: int = 42
):
    """
    Sample `nsamples` sequences of `seqlen` tokens from WikiText-2 (`wikitext-2-raw-v1`).

    Args:
        tokenizer (`Any`):
            Tokenizer of the model.
        seqlen (`int`):
            Number of tokens in each sample.
        nsamples (`int`):
            Number of samples to draw.
        split (`str`, defaults to `"train"`):
            `"train"` loads the WikiText-2 train split; `"validation"` loads the WikiText-2 *test* split.
        fuse_sequences (`bool`, defaults to `True`):
            If `True`, samples are random windows of the shuffled texts joined with blank lines (only the first
            100000 characters of the joined text are tokenized). If `False`, each sample is a random window of a
            single text that is at least `seqlen` tokens long, with its first token replaced by the BOS token when
            the tokenizer defines one.
        seed (`int`, defaults to `42`):
            Seed for `random` and for the dataset shuffle.

    Returns:
        `List[Dict[str, torch.Tensor]]`: Samples with `input_ids` and `attention_mask` of shape `(1, seqlen)`.

    Raises:
        `ValueError`: If `split` is neither `"train"` nor `"validation"`.
    """
    random.seed(seed)

    if split == "train":
        data = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    elif split == "validation":
        # "validation" is served from the WikiText-2 test split.
        data = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    else:
        # Without this, `data` would be unbound below and fail with an UnboundLocalError.
        raise ValueError(f"The split need to be 'train' or 'validation' but found {split}")

    dataset = []
    if fuse_sequences:
        data = data.shuffle(seed=seed)
        # wikitext2 is too big: only the first 100000 characters of the joined text are tokenized.
        tokenized_data = tokenizer("\n\n".join(data["text"])[:100000], return_tensors="pt")

        for _ in range(nsamples):
            i = random.randint(0, tokenized_data.input_ids.shape[1] - seqlen - 1)
            j = i + seqlen
            inp = tokenized_data.input_ids[:, i:j]
            attention_mask = torch.ones((1, seqlen), dtype=torch.int64)
            dataset.append({"input_ids": inp, "attention_mask": attention_mask})
    else:
        # NOTE: keeps drawing texts until `nsamples` of them are at least `seqlen` tokens long.
        with tqdm(total=nsamples) as pbar:
            while len(dataset) < nsamples:
                data_index = random.randint(0, len(data) - 1)

                enc = tokenizer(data[data_index]["text"], return_tensors="pt")

                if enc["input_ids"].shape[1] < seqlen:
                    continue

                start_idx = random.randint(0, enc["input_ids"].shape[1] - seqlen)
                attention_mask = torch.ones((1, seqlen), dtype=torch.int64)
                input_ids = enc["input_ids"][:, start_idx : start_idx + seqlen]

                # Add BOS token.
                if tokenizer.bos_token_id is not None:
                    input_ids[:, 0] = tokenizer.bos_token_id

                dataset.append({"input_ids": input_ids, "attention_mask": attention_mask})
                pbar.update(1)

    return dataset


def get_c4(
    tokenizer: Any, seqlen: int, nsamples: int, split: str = "train", fuse_sequences: bool = True, seed: int = 42
):
    """
    Sample `nsamples` sequences of `seqlen` tokens from the first shard of the English C4 split (`allenai/c4`).

    Args:
        tokenizer (`Any`):
            Tokenizer of the model.
        seqlen (`int`):
            Number of tokens in each sample.
        nsamples (`int`):
            Number of samples to draw.
        split (`str`, defaults to `"train"`):
            Either `"train"` or `"validation"`.
        fuse_sequences (`bool`, defaults to `True`):
            If `True`, samples are random windows of the first 10000 shuffled documents joined with blank lines.
            If `False`, each sample is a random window of a single document that is at least `seqlen` tokens long,
            with its first token replaced by the BOS token when the tokenizer defines one.
        seed (`int`, defaults to `42`):
            Seed for `random` and for the dataset shuffle.

    Returns:
        `List[Dict[str, torch.Tensor]]`: Samples with `input_ids` and `attention_mask` of shape `(1, seqlen)`.

    Raises:
        `ValueError`: If `split` is neither `"train"` nor `"validation"`.
    """
    random.seed(seed)

    if split == "train":
        data = load_dataset("allenai/c4", split="train", data_files={"train": "en/c4-train.00000-of-01024.json.gz"})
    elif split == "validation":
        data = load_dataset(
            "allenai/c4",
            split="validation",
            data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"},
        )
    else:
        # Without this, `data` would be unbound below and fail with an UnboundLocalError.
        raise ValueError(f"The split need to be 'train' or 'validation' but found {split}")

    dataset = []
    if fuse_sequences:
        # c4 is too big: only the first 10000 shuffled documents are joined and tokenized.
        subset = data.shuffle(seed=seed)[:10000]
        full_text = "\n\n".join(subset["text"])
        tokenized_data = tokenizer(full_text, return_tensors="pt")

        for _ in range(nsamples):
            i = random.randint(0, tokenized_data.input_ids.shape[1] - seqlen - 1)
            j = i + seqlen
            inp = tokenized_data.input_ids[:, i:j]
            attention_mask = torch.ones((1, seqlen), dtype=torch.int64)
            dataset.append({"input_ids": inp, "attention_mask": attention_mask})
    else:
        # NOTE: keeps drawing documents until `nsamples` of them are at least `seqlen` tokens long.
        with tqdm(total=nsamples) as pbar:
            while len(dataset) < nsamples:
                data_index = random.randint(0, len(data) - 1)

                enc = tokenizer(data[data_index]["text"], return_tensors="pt")

                if enc["input_ids"].shape[1] < seqlen:
                    continue

                start_idx = random.randint(0, enc["input_ids"].shape[1] - seqlen)
                attention_mask = torch.ones((1, seqlen), dtype=torch.int64)
                input_ids = enc["input_ids"][:, start_idx : start_idx + seqlen]

                # Add BOS token (consistent with `get_wikitext2` and `compute_perplexity`, which also use the BOS id).
                if tokenizer.bos_token_id is not None:
                    input_ids[:, 0] = tokenizer.bos_token_id

                dataset.append({"input_ids": input_ids, "attention_mask": attention_mask})
                pbar.update(1)

    return dataset


class DatasetToDevice(torch.utils.data.Dataset):
    """
    Map-style dataset over a list of sample dicts that optionally moves each sample to a device on access.

    Args:
        data (`List`):
            The samples; each one is a dict (its `.items()` are iterated when a device is set).
        device (`Optional[Union[str, torch.device]]`):
            Target device. When `None`, `__getitem__` returns the stored sample object unchanged; otherwise it
            returns a new dict whose values have been passed through `recursive_to_device`.
    """

    def __init__(self, data: List, device: Optional[Union[str, torch.device]]):
        super().__init__()
        self.data = data
        self.device = device

    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.device is None:
            return sample
        target_device = self.device
        return {key: recursive_to_device(value, target_device) for key, value in sample.items()}

    def __len__(self):
        return len(self.data)


def get_dataset_for_model(
    model_name_or_path: str,
    qconfig: "BrevitasQuantizationConfig",
    dataset_name: str,
    tokenizer: Any,
    nsamples: int = 128,
    seqlen: int = 2048,
    seed: int = 0,
    split: str = "train",
    fuse_sequences: bool = True,
    device: Optional[Union[str, torch.device]] = None,
):
    """
    Build a tokenized, sampled dataset for calibrating or evaluating a model.

    Args:
        model_name_or_path (`str`):
            A local folder containing the model or the model hosted on the Hugging Face Hub. Its config is only
            loaded when `qconfig.requires_fx_graph()` is true, to size the empty `past_key_values` inputs.
        qconfig (`BrevitasQuantizationConfig`):
            Quantization config; `qconfig.requires_fx_graph()` decides whether `past_key_values` are added.
        dataset_name (`str`):
            Dataset name. Available options are `['wikitext2', 'c4']`.
        tokenizer (`Any`):
            Tokenizer of the model.
        nsamples (`int`, defaults to `128`):
            Number of samples.
        seqlen (`int`, defaults to `2048`):
            Number of tokens in each sample.
        seed (`int`, defaults to `0`):
            Seed for `random`, `numpy` and `torch`, also forwarded to the dataset sampler.
        split (`str`, defaults to `"train"`):
            Split of the dataset. Can be either `"train"` or `"validation"`.
        fuse_sequences (`bool`, defaults to `True`):
            Forwarded to the dataset sampler: draw windows from the joined texts (`True`) or from single texts
            (`False`).
        device (`Optional[Union[str, torch.device]]`, defaults to `None`):
            Device the samples are moved to when they are accessed; `None` leaves them untouched.

    Returns:
        `DatasetToDevice`: The samples, each a dict with `input_ids` and `attention_mask`, plus `past_key_values`
        when an fx graph is required.

    Raises:
        `ValueError`: If `split` or `dataset_name` is not supported.
    """
    for seed_everything in (random.seed, np.random.seed, torch.random.manual_seed):
        seed_everything(seed)

    dataset_loaders = {"wikitext2": get_wikitext2, "c4": get_c4}
    if split not in ("train", "validation"):
        raise ValueError(f"The split need to be 'train' or 'validation' but found {split}")
    if dataset_name not in dataset_loaders:
        raise ValueError(f"Expected a value in {list(dataset_loaders.keys())} but found {dataset_name}")

    load_samples = dataset_loaders[dataset_name]
    samples = load_samples(
        tokenizer=tokenizer, nsamples=nsamples, seqlen=seqlen, split=split, fuse_sequences=fuse_sequences, seed=seed
    )

    # A torch.fx.GraphModule cannot take optional inputs and is traced with `past_key_values`, so every sample
    # receives an empty (zero-length) key/value cache for each layer.
    if qconfig.requires_fx_graph():
        model_config = AutoConfig.from_pretrained(model_name_or_path)
        normalized_config = NormalizedConfigManager.get_normalized_config_class(model_config.model_type)(model_config)

        num_heads = normalized_config.num_attention_heads
        num_kv_heads = (
            normalized_config.num_key_value_heads if hasattr(normalized_config, "num_key_value_heads") else num_heads
        )
        head_dim = normalized_config.hidden_size // num_heads
        num_layers = normalized_config.num_layers

        for sample in samples:
            cache_device = sample["input_ids"].device
            sample["past_key_values"] = tuple(
                (
                    torch.zeros(1, num_kv_heads, 0, head_dim, device=cache_device),
                    torch.zeros(1, num_kv_heads, 0, head_dim, device=cache_device),
                )
                for _ in range(num_layers)
            )

    return DatasetToDevice(samples, device=device)
