from typing import Optional, List, Dict, Any, Union, Callable, AsyncGenerator
import logging
import torch
import os
from functools import lru_cache
from packaging import version
import json

from transformers import PreTrainedTokenizerBase
from dataclasses import dataclass, field

import jinja2
from jinja2.exceptions import TemplateError
from jinja2.sandbox import ImmutableSandboxedEnvironment

from maga_transformer.openai.renderers.custom_renderer import CustomChatRenderer, \
    RendererParams, StreamResponseObject, RenderedInputs, RendererInfo
from maga_transformer.models.base_model import GenerateOutput
from maga_transformer.openai.api_datatype import ChatMessage, GPTFunctionDefinition, RoleEnum, \
    ChatCompletionRequest, ChatCompletionResponseStreamChoice, DeltaMessage, FinisheReason, UsageInfo
from maga_transformer.utils.multimodal_util import MMUrlType, MMPreprocessConfig


# Fallback Jinja2 chat template, used by BasicRenderer when the tokenizer
# supplies no template of its own. Every message is rendered as
# `<|im_start|>{role}\n{content}<|im_end|>\n`; when `add_generation_prompt`
# is truthy, an opening `<|im_start|>assistant\n` header is appended.
DEFAULT_CHAT_API_TEMPLATE = "".join((
    "{% for message in messages %}",
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}",
    "{% endfor %}",
    "{% if add_generation_prompt %}",
    "{{ '<|im_start|>assistant\n' }}",
    "{% endif %}",
))

@dataclass
class PromptWithMMInput:
    """Container pairing a prompt string with multimodal (MM) input URLs and optional per-input metadata."""
    # Prompt text (required).
    prompt: str
    # URLs of the multimodal inputs (required; no default).
    urls: List[str]
    # Types of the multimodal inputs; a fresh empty list per instance by default.
    # NOTE(review): presumably index-aligned with `urls` - confirm against callers.
    mm_types: List[MMUrlType] = field(default_factory=list)
    # Preprocessing configs for the multimodal inputs; a fresh empty list per instance by default.
    # NOTE(review): presumably index-aligned with `urls` - confirm against callers.
    preprocess_configs: List[MMPreprocessConfig] = field(default_factory=list)

# This class is designed to replace the `PreTrainedTokenizerBase.apply_chat_template` functionality,
# providing more capability to customize the template.
# More specifically, this class allows templates to use the `functions` field, following the OpenAI chat API format.
# Besides that, all other template elements are compatible with `PreTrainedTokenizerBase.apply_chat_template`.
class BasicRenderer(CustomChatRenderer):
    """Render chat completion requests into token ids using Jinja2 chat templates.

    The template source is taken from the tokenizer (`chat_template`, then
    `default_chat_template`); when neither is available the module-level
    `DEFAULT_CHAT_API_TEMPLATE` is used. The tokenizer may provide either a
    single template string or a dict of named templates. Templates are compiled
    once at construction time and looked up per request.
    """

    def __init__(self,
                 tokenizer: PreTrainedTokenizerBase,
                 renderer_params: RendererParams,
    ):
        """Resolve and compile the chat template(s) and register stop words.

        Args:
            tokenizer: tokenizer whose chat template and special tokens are used.
            renderer_params: forwarded unchanged to the base renderer.

        Raises:
            ImportError: if the installed jinja2 is older than 3.0.0.
            Exception: if the chat template is neither a `str` nor a `dict`, or
                if the default template key is missing from the compiled templates.
        """
        super().__init__(tokenizer, renderer_params)

        # Strict `<` so that jinja2 3.0.0 itself is accepted, as the message promises.
        if version.parse(jinja2.__version__) < version.parse("3.0.0"):
            raise ImportError(
                "apply_chat_template requires jinja2>=3.0.0 to be installed. "
                "Your version is " f"{jinja2.__version__}."
            )

        self.add_generation_prompt = True
        self.chat_template = None
        self.special_tokens_map = {}

        # Prefer `chat_template`, then `default_chat_template`. Explicit None checks
        # (rather than `assert`) keep the fallback working under `python -O`.
        for attr_name in ("chat_template", "default_chat_template"):
            try:
                candidate = getattr(tokenizer, attr_name, None)
            except Exception:
                # The attribute exists but reading it failed; treat it as missing.
                candidate = None
            if candidate is not None:
                self.chat_template = candidate
                break

        if self.chat_template is None:
            logging.info(f"tokenizer {tokenizer} has no chat_template nor "
                          "default_chat_template attribute. Use default template.")
            self.chat_template = DEFAULT_CHAT_API_TEMPLATE
            # The fallback template closes every message with this marker.
            self.add_extra_stop_words(["<|im_end|>"])

        # Best effort: register the tokenizer's special tokens as stop words.
        # A failure here must not prevent the renderer from being constructed.
        try:
            special_tokens_map = getattr(tokenizer, "special_tokens_map", None)
            if special_tokens_map is not None:
                self.special_tokens_map = special_tokens_map
                for k, v in special_tokens_map.items():
                    if isinstance(v, str):
                        stop_words = [v]
                    elif isinstance(v, list):
                        stop_words = v
                    else:
                        # Values of any other type are not registered.
                        continue
                    logging.info(f"special token [{v}]({k}) added as stop words.")
                    self.add_extra_stop_words(stop_words)
        except Exception as e:
            logging.warning(f"failed to add special tokens as stop words: {e}")

        # Best effort as well, for the same reason as above.
        try:
            additional_special_tokens = getattr(tokenizer, "additional_special_tokens", None)
            if additional_special_tokens is not None:
                logging.info(f"additional special tokens {additional_special_tokens} "
                             "added as stop words.")
                self.add_extra_stop_words(additional_special_tokens)
        except Exception as e:
            logging.warning(f"failed to add additional special tokens as stop words: {e}")

        # try:
        #     if tokenizer.added_tokens_decoder != None:
        #         for token_id, added_token in tokenizer.added_tokens_decoder.items():
        #             logging.info(f"added token [{token_id}]({added_token}) added as stop words.")
        #             self.add_extra_stop_word_ids([[token_id]])
        # except:
        #     pass

        logging.info(f"found chat template to use: {self.chat_template}")
        # Template keys can be overridden through environment variables.
        self.default_template_key = os.environ.get("DEFAULT_CHAT_TEMPLATE_KEY", "default")
        self.default_tool_use_template_key = os.environ.get("DEFAULT_TOOL_USE_TEMPLATE_KEY", "tool_use")
        self.compiled_template_map: Dict[str, jinja2.Template] = {}

        if isinstance(self.chat_template, dict):
            # Several named templates: compile each one under its own key.
            for key, template in self.chat_template.items():
                self.compiled_template_map[key] = self._compile_jinja_template(template)
        elif isinstance(self.chat_template, str):
            # A single template is registered under the default key.
            self.compiled_template_map[self.default_template_key] = self._compile_jinja_template(self.chat_template)
        else:
            raise Exception(f"chat template [{self.chat_template}] "
                            f"of type [{type(self.chat_template)}] is not supported.")

        if self.default_template_key not in self.compiled_template_map:
            raise Exception(f"default template key [{self.default_template_key}] not found "
                            f"in chat templates: [{self.compiled_template_map.keys()}]")
        # Without a dedicated tool-use template, function calls use the default one.
        if self.default_tool_use_template_key not in self.compiled_template_map:
            self.default_tool_use_template_key = self.default_template_key
        logging.info(f"compiled chat templates to use: {self.compiled_template_map.keys()}")

    def get_renderer_info(self) -> RendererInfo:
        """Return the base renderer info with `template` set to the chat template in use."""
        renderer_info = super().get_renderer_info()
        renderer_info.template = self.chat_template
        return renderer_info

    def _compile_jinja_template(self, chat_template) -> jinja2.Template:
        """Compile `chat_template` into a `jinja2.Template`, reusing cached results.

        Args:
            chat_template: Jinja2 template source; must be hashable (e.g. `str`).

        Returns:
            The compiled template.
        """
        return BasicRenderer._compile_jinja_template_cached(chat_template)

    @staticmethod
    @lru_cache(maxsize=128)
    def _compile_jinja_template_cached(chat_template) -> jinja2.Template:
        """Compile a template in a sandboxed environment, cached by template source.

        The cache is keyed on the template source only, not on a renderer
        instance, so it never keeps renderer objects alive.
        """

        def raise_exception(message):
            # Exposed to templates so they can abort rendering with an error.
            raise TemplateError(message)

        jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
        jinja_env.globals["raise_exception"] = raise_exception
        return jinja_env.from_string(chat_template)

    def _get_template(self, request: ChatCompletionRequest) -> jinja2.Template:
        """Select the compiled template to use for `request`.

        A `user_template` on the request is compiled and used directly. Otherwise
        `request.template_key` is used when set; failing that, the tool-use key is
        used when the request carries functions, and the default key if not.

        Raises:
            ValueError: if both `user_template` and `template_key` are set.
            KeyError: if the selected key has no compiled template.
        """
        if request.user_template:
            if request.template_key:
                raise ValueError("template_key and user_template can not be used together.")
            return self._compile_jinja_template(request.user_template)
        template_key = self.default_tool_use_template_key \
            if request.functions else self.default_template_key
        template_key = request.template_key or template_key
        try:
            return self.compiled_template_map[template_key]
        except KeyError:
            # Same exception type as a plain dict lookup, with an actionable message.
            raise KeyError(f"template key [{template_key}] not found in chat templates: "
                           f"[{list(self.compiled_template_map.keys())}]") from None

    def render_chat(self, request: ChatCompletionRequest) -> RenderedInputs:
        """Render `request` with the selected template and tokenize the result.

        Returns:
            `RenderedInputs` whose `input_ids` are the tokenizer encoding of the
            rendered prompt string.
        """
        template = self._get_template(request)
        # Round-trip through JSON so the template receives plain dicts and lists.
        request_dict = json.loads(request.model_dump_json(exclude_none=True))
        render_args = {
            "messages": request_dict['messages'],
            "json": json,
            "add_generation_prompt": self.add_generation_prompt,
        }
        # Expose the tokenizer's special tokens to the template as variables.
        render_args.update(self.special_tokens_map)
        # functions with none value may occur exception in llama3 template
        if request_dict.get('functions', None):
            render_args['functions'] = request_dict['functions']
        rendered = template.render(
            **render_args
        )
        # Serializing the whole request is costly; only do it when DEBUG is enabled.
        if logging.getLogger().isEnabledFor(logging.DEBUG):
            logging.debug(f"request [{request.model_dump_json(indent=4)}] rendered string: [{rendered}]]")
        return RenderedInputs(input_ids=self.tokenizer.encode(rendered))
