import copy
import json
import re
import logging
import torch
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any, Union, Callable, Tuple, AsyncGenerator
import functools

from maga_transformer.models.base_model import GenerateOutput, GenerateOutputs
from maga_transformer.config.generate_config import GenerateConfig
from maga_transformer.openai.renderers.qwen_tool_renderer import QwenToolRenderer, QwenToolStreamStatus
from maga_transformer.tokenizer.tokenization_qwen import QWenTokenizer
from transformers import Qwen2Tokenizer
from maga_transformer.openai.api_datatype import ChatMessage, GPTFunctionDefinition, \
    ChatCompletionRequest, RoleEnum, FunctionCall, ChatCompletionResponseStreamChoice, \
    DeltaMessage, FinisheReason, UsageInfo, RendererInfo, PromptTokensDetails
from maga_transformer.openai.renderers.custom_renderer import CustomChatRenderer, RendererParams, \
    StreamResponseObject, RenderedInputs, StreamStatus, StreamStatusSync, OutputDelta, ThinkStatus
from maga_transformer.openai.renderers.basic_renderer import BasicRenderer
from maga_transformer.openai.renderer_factory_register import register_renderer
from maga_transformer.utils.word_util import get_stop_word_slices, truncate_response_with_stop_words, is_truncated

QwenTokenizerTypes = Union[QWenTokenizer, Qwen2Tokenizer]

TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}"""

REACT_INSTRUCTION = """Answer the following questions as best you can. You have access to the following APIs:

{tools_text}

Use the following format:

Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tools_name_text}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question

Begin!"""
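
# Illustrative rendering (not used at runtime): for a single hypothetical tool named
# "get_weather" with description "Queries current weather.", TOOL_DESC expands to
#
#   get_weather: Call this tool to interact with the get_weather API. What is the
#   get_weather API useful for? Queries current weather. Parameters: {...}
#
# and REACT_INSTRUCTION embeds that text plus the Question/Thought/Action/Action Input/
# Observation loop shown above; _parse_function_response later splits the model's
# ReAct-formatted output back into a FunctionCall.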

DUMMY_THOUGHT = {
    "en": "\nThought: I now know the final answer.\nFinal answer: ",
    "zh": "\nThought: 我会作答了。\nFinal answer: ",
}

_TEXT_COMPLETION_CMD = object()

class QwenStreamStatus(StreamStatus):
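    r"""Stream status for ReAct-style function calling.

    The renderer always withholds the last len('\nAction:') characters of the decoded
    output so that a partially generated "\nAction:" marker is never streamed to the
    client as normal content. Once the marker is seen, generating_function_call is set
    and streaming pauses until _flush_buffer parses the buffered tail into a function call.
    """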
    generating_function_call: bool = False
    total_output_string: str = ""

    def __init__(self, request: ChatCompletionRequest):
        super().__init__(request)

    def update_result(self):
        self.last_token_length = len(self.output_ids) - len(self.last_output_ids)
        self.last_output_ids = self.output_ids
        self.responded_string = self.total_output_string[: - len('\nAction:')]

    @property
    def responded_length(self):
        return len(self.responded_string)

    @property
    def output_length(self):
        return len(self.total_output_string)

    def check_stop_reason(self):
        if self.finish_reason is None:
            logging.debug(f"no finish reason found for output [{self.responded_string}]; defaulting to stop.")
            self.finish_reason = FinisheReason.stop

class QwenStreamStatusSync(StreamStatusSync):
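    """Synchronous counterpart of QwenStreamStatus, used by the sync streaming path."""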
    generating_function_call: bool = False
    total_output_string: str = ""

    def __init__(self, request: ChatCompletionRequest):
        super().__init__(request)

    def update_result(self):
        self.responded_string = self.total_output_string[: - len('\nAction:')]

    @property
    def responded_length(self):
        return len(self.responded_string)

    @property
    def output_length(self):
        return len(self.total_output_string)

    def check_stop_reason(self):
        if self.finish_reason is None:
            logging.debug(f"no finish reason found for output [{self.responded_string}]; defaulting to stop.")
            self.finish_reason = FinisheReason.stop

@dataclass
class ProcessedOutput:
    output_str: str
    output_token_length: int
    finish_reason: Optional[FinisheReason]

# TODO(wangyin): pass `max_window_size` in from the caller.
def make_context(
    tokenizer: QwenTokenizerTypes,
    query: str,
    history: List[Tuple[str, str]] = [],
    system: str = "",
    max_window_size: int = 6144,
):
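    """Build a ChatML-style prompt for Qwen and return (raw_text, context_tokens).

    The rendered prompt has the shape::

        <|im_start|>system
        {system}<|im_end|>
        <|im_start|>user
        {query}<|im_end|>
        <|im_start|>assistant

    with history turns inserted between the system block and the final query. History is
    budgeted newest turn first, so the oldest turns are dropped once the token count would
    exceed max_window_size; the surviving turns keep their chronological order.
    """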
    history = copy.deepcopy(history)
    im_start, im_end = "<|im_start|>", "<|im_end|>"
    im_start_tokens = [tokenizer.im_start_id]
    im_end_tokens = [tokenizer.im_end_id]
    nl_tokens = tokenizer.encode("\n")

    def _tokenize_str(role, content):
        return f"{role}\n{content}", tokenizer.encode(
            role, allowed_special=set()
        ) + nl_tokens + tokenizer.encode(content, allowed_special=set())

    system_text, system_tokens_part = _tokenize_str("system", system)
    system_tokens = im_start_tokens + system_tokens_part + im_end_tokens

    raw_text = ""
    context_tokens = []

    for turn_query, turn_response in reversed(history):
        query_text, query_tokens_part = _tokenize_str("user", turn_query)
        query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
        response_text, response_tokens_part = _tokenize_str(
            "assistant", turn_response
        )
        response_tokens = im_start_tokens + response_tokens_part + im_end_tokens

        next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
        prev_chat = (
            f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
        )

        current_context_size = (
            len(system_tokens) + len(next_context_tokens) + len(context_tokens)
        )
        if current_context_size < max_window_size:
            context_tokens = next_context_tokens + context_tokens
            raw_text = prev_chat + raw_text
        else:
            break

    context_tokens = system_tokens + context_tokens
    raw_text = f"{im_start}{system_text}{im_end}" + raw_text
    context_tokens += (
        nl_tokens
        + im_start_tokens
        + _tokenize_str("user", query)[1]
        + im_end_tokens
        + nl_tokens
        + im_start_tokens
        + tokenizer.encode("assistant")
        + nl_tokens
    )
    raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
    return raw_text, context_tokens

class QwenRenderer(CustomChatRenderer):
    def __init__(self, tokenizer: QwenTokenizerTypes, renderer_params: RendererParams):
        super().__init__(tokenizer, renderer_params)
        self.add_extra_stop_word_ids([[37763, 367, 25], [151643]]) # [37763, 367, 25] == "Observation:", 151643 == <|endoftext|>

        self.qwen_tool_renderer = QwenToolRenderer(tokenizer, renderer_params)

        self.template_chat_renderer: Optional[BasicRenderer] = None
        try:
            if tokenizer.chat_template is not None:
                logging.info(f"qwen model has chat_template [{tokenizer.chat_template}], "
                             "which will be used for non-function-call dialogue.")
                self.template_chat_renderer = BasicRenderer(tokenizer, renderer_params)
        except AttributeError:
            pass

    def render_chat(self, request: ChatCompletionRequest) -> RenderedInputs:
        if request.tools:
            return self.qwen_tool_renderer.render_chat(request)

        if (self.template_chat_renderer is not None) and (not request.functions):
            return self.template_chat_renderer.render_chat(request)

        query, history, system = self.parse_messages(request.messages, request.functions)
        logging.debug(f"parsed query: {query}, history: {history}, system: {system}")
        input_ids = []
        if query is _TEXT_COMPLETION_CMD:
            input_ids = self.text_complete_last_message(history)
        else:
            assert isinstance(query, str)
            input_ids = make_context(self.tokenizer, query, history, system)[1]
        return RenderedInputs(input_ids=input_ids)

    def text_complete_last_message(self, history):
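        """Re-encode the history as a ChatML prompt whose final assistant block is left
        open (the trailing <|im_end|> is stripped), so generation continues the last
        assistant message instead of starting a new turn."""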
        im_start = "<|im_start|>"
        im_end = "<|im_end|>"
        prompt = f"{im_start}system\nYou are a helpful assistant.{im_end}"
        for i, (query, response) in enumerate(history):
            query = query.lstrip("\n").rstrip()
            response = response.lstrip("\n").rstrip()
            prompt += f"\n{im_start}user\n{query}{im_end}"
            prompt += f"\n{im_start}assistant\n{response}{im_end}"
        prompt = prompt[: -len(im_end)]

        return self.tokenizer.encode(prompt)

    def parse_messages(
            self,
            messages: List[ChatMessage],
            functions: Optional[List[GPTFunctionDefinition]] = None
    ):
        if all(m.role != "user" for m in messages):
            raise ValueError("At least one message must be from user.")

        messages = copy.deepcopy(messages)
        if messages[0].role == 'system':
            system = messages.pop(0).content.lstrip('\n').rstrip()
        else:
            system = 'You are a helpful assistant.'

        if functions:
            tools_text = []
            tools_name_text = []
            for func_info in functions:
                name = func_info.name
                name_m = func_info.name_for_model or name
                name_h = func_info.name_for_human or name
                desc = func_info.description
                desc_m = func_info.description_for_model or desc
                tool = TOOL_DESC.format(
                    name_for_model=name_m,
                    name_for_human=name_h,
                    # Hint: You can add the following format requirements in description:
                    #   "Format the arguments as a JSON object."
                    #   "Enclose the code within triple backticks (`) at the beginning and end of the code."
                    description_for_model=desc_m,
                    parameters=json.dumps(func_info.parameters, ensure_ascii=False),
                )
                tools_text.append(tool)
                tools_name_text.append(name_m)
            tools_text = "\n\n".join(tools_text)
            tools_name_text = ", ".join(tools_name_text)
            instruction = (REACT_INSTRUCTION.format(
                tools_text=tools_text,
                tools_name_text=tools_name_text,
            ).lstrip('\n').rstrip())
        else:
            instruction = ''

        messages_with_fncall = messages
        messages = []
        for m_idx, m in enumerate(messages_with_fncall):
            role, content, func_call = m.role, m.content, m.function_call
            content = content or ""
            content = content.lstrip("\n").rstrip()
            if role == "function":
                if (len(messages) == 0) or (messages[-1].role != "assistant"):
                    raise ValueError("Invalid request: Expecting role assistant before role function.")
                messages[-1].content += f'\nObservation: {content}'
                if m_idx == len(messages_with_fncall) - 1:
                    # add a prefix for text completion
                    messages[-1].content += '\nThought:'
            elif role == 'assistant':
                if len(messages) == 0:
                    raise ValueError("Invalid request: Expecting role user before role assistant.")
                if func_call is None:
                    if functions:
                        content = f'Thought: I now know the final answer.\nFinal Answer: {content}'
                else:
                    f_name, f_args = func_call.name, func_call.arguments
                    if not content.startswith('Thought:'):
                        content = f'Thought: {content}'
                    content = f'{content}\nAction: {f_name}\nAction Input: {f_args}'
                if messages[-1].role == 'user':
                    messages.append(
                        ChatMessage(role=RoleEnum.assistant, content=content.lstrip('\n').rstrip())
                    )
                else:
                    messages[-1].content += '\n' + content
            elif role == 'user':
                messages.append(
                    ChatMessage(role='user', content=content.lstrip('\n').rstrip()))
            else:
                raise ValueError(f"Invalid request: Incorrect role {role}.")

        query = _TEXT_COMPLETION_CMD
        if messages[-1].role == 'user':
            query = messages[-1].content
            messages = messages[:-1]

        history = []  # [(Q1, A1), (Q2, A2), ..., (Q_last_turn, A_last_turn)]
        for i in range(0, len(messages), 2):
            if messages[i].role == 'user' and messages[i + 1].role == 'assistant':
                usr_msg = messages[i].content.lstrip('\n').rstrip()
                bot_msg = messages[i + 1].content.lstrip('\n').rstrip()
                if instruction and (i == len(messages) - 2):
                    usr_msg = f'{instruction}\n\nQuestion: {usr_msg}'
                    instruction = ''
                history.append([usr_msg, bot_msg])
            else:
                raise ValueError(
                    "Invalid request: Expecting exactly one user (or function) role before every assistant role."
                )
        if instruction:
            assert query is not _TEXT_COMPLETION_CMD
            query = f'{instruction}\n\nQuestion: {query}'
        return query, history, system

    def _parse_function_response(self, response: str) -> Optional[DeltaMessage]:
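        """Extract a ReAct-style function call from the model output.

        Illustrative example (the tool name is hypothetical)::

            Thought: I need to look this up.
            Action: get_weather
            Action Input: {"city": "Beijing"}

        parses into DeltaMessage(content='Thought: I need to look this up.',
        function_call=FunctionCall(name='get_weather', arguments='{"city": "Beijing"}')).
        Returns None when no "Action:" / "Action Input:" pair is found.
        """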
        func_name, func_args = "", ""
        i = response.rfind("\nAction:")
        j = response.rfind("\nAction Input:")
        k = response.rfind("\nObservation:")
        if 0 <= i < j:  # If the text has `Action` and `Action input`,
            if k < j:  # but does not contain `Observation`,
                # then it is likely that `Observation` is omitted by the LLM,
                # because the output text may have discarded the stop word.
                response = response.rstrip() + "\nObservation:"  # Add it back.
            k = response.rfind("\nObservation:")
            func_name = response[i + len("\nAction:") : j].strip()
            func_args = response[j + len("\nAction Input:") : k].strip()
        logging.info(f"parsed function from response: [{response}]: {func_name}, {func_args}")
        if func_name:
            return DeltaMessage(
                content=response[:i],
                function_call=FunctionCall(name=func_name, arguments=func_args),
            )
        return None
        # z = response.rfind("\nFinal Answer: ")
        # if z >= 0:
        #     response = response[z + len("\nFinal Answer: ") :]

    async def _update_single_status(self, status: StreamStatus, 
                                    output: GenerateOutput, max_new_tokens: int,
                                    stop_words_str: List[str], stop_word_slice_list: List[str], is_streaming: bool) -> OutputDelta:
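        r"""Compute the next streaming delta for one in-flight response.

        Normal text is emitted incrementally, but the last len('\nAction:') characters are
        always withheld; once the output ends with "\nAction:" the status switches into
        function-call mode and only empty deltas are emitted until _flush_buffer parses the
        buffered tail into a function call.
        """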
        if status.request.tools:
            return await self.qwen_tool_renderer._update_single_status(
                status, output, max_new_tokens, stop_words_str, stop_word_slice_list, is_streaming)

        if not isinstance(status, QwenStreamStatus):
            return await super()._update_single_status(status, output, max_new_tokens, stop_words_str, stop_word_slice_list, is_streaming)
        if status.finish_reason is not None:
            return await self._create_empty_delta(status.output.aux_info)
        status.update_output(output,
                             self._clean_output_ids,
                             functools.partial(self._check_finish_reason, max_new_tokens=max_new_tokens),
                             self._remove_stop_word_ids)
        status.total_output_string = self.tokenizer.decode(status.output_ids).strip()
        # For some tokenizers (e.g. ChatGLM), decoding a single token differs from decoding a list of tokens:
        # a trailing U+FFFD means the last token ends mid-way through a multi-byte character,
        # so wait for more tokens before emitting a delta.
        if status.total_output_string and status.total_output_string[-1] == u'\uFFFD':
            return await self._create_empty_delta(output.aux_info)
        if status.total_output_string.endswith("\nAction:"):
            status.generating_function_call = True
            return await self._create_empty_delta(output.aux_info)
        if status.generating_function_call:
            return await self._create_empty_delta(output.aux_info)
        if is_truncated(status.total_output_string, stop_words_str, is_streaming):
            status.finish_reason = FinisheReason.stop
            return await self._create_empty_delta(output.aux_info)
        # Hold back the last len('\nAction:') characters so a partially generated
        # "\nAction:" marker is never streamed out as normal content.
        if len(status.total_output_string) > status.responded_length + len('\nAction:'):
            status.delta_output_string = status.total_output_string[status.responded_length : status.output_length - len('\nAction:')]
            if is_truncated(status.delta_output_string, stop_word_slice_list, is_streaming):
                return await self._create_empty_delta(output.aux_info)
            else:
                status.update_result()
                return OutputDelta(
                    status.delta_output_string,
                    await self._generate_log_probs(status, output),
                    status.input_token_length,
                    status.output_token_length,
                    status.reuse_length)
        return await self._create_empty_delta(output.aux_info)

    #override
    async def _create_status_list(self, n: int, request: ChatCompletionRequest) -> List[StreamStatus]:
        if request.tools:
            return [QwenToolStreamStatus(request) for _ in range(n)]
        if request.functions:
            return [QwenStreamStatus(request) for _ in range(n)]
        else:
            return [StreamStatus(request) for _ in range(n)]

    #override
    async def _flush_buffer(self, buffer_list: List[StreamStatus], stop_words_str: List[str], is_streaming: bool, think_status: ThinkStatus):
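        """Emit the final deltas once generation ends: buffered function-call text is parsed
        into a FunctionCall (falling back to plain text if parsing fails), while normal text
        is truncated at stop words before being sent."""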
        if buffer_list[0].request.tools:
            return await self.qwen_tool_renderer._flush_buffer(buffer_list, stop_words_str, is_streaming, think_status)

        if (not isinstance(buffer_list[0], QwenStreamStatus)):
            return await super()._flush_buffer(buffer_list, stop_words_str, is_streaming, think_status)
        output_items: List[OutputDelta] = []
        for status in buffer_list:
            if status.generating_function_call:
                function_message = self._parse_function_response(status.total_output_string[status.responded_length:])
                if function_message is None:
                    logging.warning(f"output [{status.total_output_string}] failed to parse a function call "
                                    f"from offset [{status.responded_length}]; treating it as normal output.")
                    function_message = ""
                else:
                    status.finish_reason = FinisheReason.function_call
                output_items.append(OutputDelta(
                    function_message,
                    await self._generate_log_probs(status, status.output),
                    status.input_token_length,
                    status.output_token_length,
                    status.reuse_length))
            else:
                trunc_string = truncate_response_with_stop_words(status.total_output_string[status.responded_length:], stop_words_str, is_streaming)
                output_items.append(OutputDelta(
                    trunc_string,
                    await self._generate_log_probs(status, status.output),
                    status.input_token_length,
                    status.output_token_length,
                    status.reuse_length))
        return await self._generate_stream_response(output_items, think_status)

    #override
    def _update_single_status_sync(self,
                                   status: StreamStatusSync,
                                   input_len, # output.aux_info
                                   output_len, # output.aux_info
                                   reuse_len, # output.aux_info
                                   all_probs: torch.Tensor,
                                   output_ids: torch.Tensor,
                                   max_new_tokens: int,
                                   stop_words_str: List[str],
                                   stop_word_slice_list: List[str],
                                   is_streaming: bool) -> OutputDelta:
        # function call is disabled when logprobs is required.
        if not isinstance(status, QwenStreamStatusSync):
            return super()._update_single_status_sync(status,
                                                       input_len, output_len, reuse_len,
                                                       all_probs, output_ids,
                                                       max_new_tokens, stop_words_str,
                                                       stop_word_slice_list,
                                                       is_streaming)
        if status.finish_reason is not None:
            return self._create_empty_delta_sync(input_len, output_len, reuse_len)
        status.update_output_sync(output_ids, input_len,
                                  self._clean_output_ids,
                                  functools.partial(self._check_finish_reason, max_new_tokens=max_new_tokens),
                                  self._remove_stop_word_ids)
        status.total_output_string = self.tokenizer.decode(status.output_ids).strip()
        # As in the async path: a trailing U+FFFD indicates an incomplete multi-byte character,
        # so wait for more tokens before emitting a delta.
        if status.total_output_string and status.total_output_string[-1] == u'\uFFFD':
            return self._create_empty_delta_sync(input_len, output_len, reuse_len)
        if status.total_output_string.endswith("\nAction:"):
            status.generating_function_call = True
            return self._create_empty_delta_sync(input_len, output_len, reuse_len)
        if status.generating_function_call:
            return self._create_empty_delta_sync(input_len, output_len, reuse_len)
        if is_truncated(status.total_output_string, stop_words_str, is_streaming):
            status.finish_reason = FinisheReason.stop
            return self._create_empty_delta_sync(input_len, output_len, reuse_len)
        # Hold back the last len('\nAction:') characters, mirroring _update_single_status above.
        if len(status.total_output_string) > status.responded_length + len('\nAction:'):
            status.delta_output_string = status.total_output_string[status.responded_length : status.output_length - len('\nAction:')]
            if is_truncated(status.delta_output_string, stop_word_slice_list, is_streaming):
                return self._create_empty_delta_sync(input_len, output_len, reuse_len)
            else:
                status.update_result()
                return OutputDelta(
                    output_str=status.delta_output_string,
                    logprobs=self._generate_log_probs_sync(status, all_probs, output_ids),
                    input_length=input_len,
                    output_length=output_len,
                    reuse_length=reuse_len)
        return self._create_empty_delta_sync(input_len, output_len, reuse_len)

    #override
    def _create_status_list_sync(self, n: int, body: str) -> List[StreamStatusSync]:
        request = self.getRequest(body)
        if request.logprobs:
            return [StreamStatusSync(request) for _ in range(n)]
        else:
            return [QwenStreamStatusSync(request) for _ in range(n)]

    #override
    def _flush_buffer_sync(self,
                           buffer_list: List[StreamStatusSync],
                           input_len_list, output_len_list, reuse_len_list,
                           all_probs_list, output_ids_list,
                           stop_words_str: List[str],
                           is_streaming: bool):
        if (not isinstance(buffer_list[0], QwenStreamStatusSync)):
            return super()._flush_buffer_sync(buffer_list,
                                              input_len_list, output_len_list, reuse_len_list,
                                              all_probs_list, output_ids_list,
                                              stop_words_str,
                                              is_streaming)
        output_items: List[OutputDelta] = []
        for status, input_len, output_len, reuse_len, all_probs, output_ids in zip(
                buffer_list,
                input_len_list, output_len_list, reuse_len_list,
                all_probs_list, output_ids_list
                ):
            if status.generating_function_call:
                function_message = self._parse_function_response(status.total_output_string[status.responded_length:])
                if function_message is None:
                    logging.warning(f"output [{status.total_output_string}] failed to parse a function call "
                                    f"from offset [{status.responded_length}]; treating it as normal output.")
                    function_message = ""
                else:
                    status.finish_reason = FinisheReason.function_call
                output_items.append(OutputDelta(
                    function_message,
                    self._generate_log_probs_sync(status, all_probs, output_ids),
                    input_len,
                    output_len,
                    reuse_len))
            else:
                trunc_string = truncate_response_with_stop_words(status.total_output_string[status.responded_length:], stop_words_str, is_streaming)
                output_items.append(OutputDelta(
                    trunc_string,
                    self._generate_log_probs_sync(status, all_probs, output_ids),
                    input_len,
                    output_len,
                    reuse_len))
        return self._generate_stream_response_sync(output_items)

    def get_renderer_info(self) -> RendererInfo:
        renderer_info = super().get_renderer_info()
        if self.template_chat_renderer:
            renderer_info.template = self.template_chat_renderer.chat_template
        return renderer_info

register_renderer('qwen', QwenRenderer)
register_renderer('qwen_7b', QwenRenderer)
register_renderer('qwen_13b', QwenRenderer)
register_renderer('qwen_1b8', QwenRenderer)
register_renderer('qwen_2', QwenRenderer)
register_renderer('qwen_2_moe', QwenRenderer)
