# Copyright (c) 2023 Uber Technologies, Inc.
#
# <p>Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
# except in compliance with the License. You may obtain a copy of the License at
# <p>http://www.apache.org/licenses/LICENSE-2.0
#
# <p>Unless required by applicable law or agreed to in writing, software distributed under the
# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing permissions and
# limitations under the License.

import copy
import difflib
import logging
import multiprocessing
import re
from typing import Dict, List, Optional, Tuple

import attr
import toml
from polyglot_piranha import (PiranhaArguments, PiranhaOutputSummary, Rule,
                              RuleGraph, execute_piranha)

from piranha_playground.rule_inference.controller import Controller
from piranha_playground.rule_inference.graph_parser import GraphParser
from piranha_playground.rule_inference.piranha_chat import (
    PiranhaChatException, PiranhaGPTChat)
from piranha_playground.rule_inference.rule_application import \
    run_piranha_with_timeout
from piranha_playground.rule_inference.static_inference import (Inference,
                                                                QueryWriter)
from piranha_playground.rule_inference.template_parser import TemplateParser
from piranha_playground.rule_inference.utils.logger_formatter import \
    CustomFormatter
from piranha_playground.rule_inference.utils.node_utils import NodeUtils
from piranha_playground.rule_inference.utils.pretty_toml import PrettyTOML
from piranha_playground.rule_inference.utils.rule_utils import (RawRule,
                                                                RawRuleGraph)

# Module-wide logger used by the agent below.
logger = logging.getLogger("PiranhaChat")

# Stream handler whose records are rendered through the project's CustomFormatter.
ch = logging.StreamHandler()
ch.setFormatter(CustomFormatter())
ch.setLevel(logging.DEBUG)

# Both the logger and its handler let everything from DEBUG upwards through.
logger.addHandler(ch)
logger.setLevel(logging.DEBUG)


class PiranhaAgentError(Exception):
    """Error raised by the agent when a rule cannot be inferred, validated or improved."""


@attr.s
class PiranhaAgent:
    """
    An agent that uses OpenAI's chat models for inferring Piranha rules.
    The agent takes a pair of source and target code snippets, finds a transformation rule between them,
    and validates the rule's effectiveness by testing if the rule can refactor the source code into the target code.

    :ivar source_code: The source code to refactor
    :ivar target_code: The target code we want to achieve after refactoring
    :ivar language: The programming language of the source code (default is "java")
    :ivar hints: Any hints or information that might help in inferring the rules
    :ivar explanation: Explanation of the rule generated by GPT
    """

    source_code = attr.ib(type=str)
    target_code = attr.ib(type=str)
    language = attr.ib(default="java")
    hints = attr.ib(default="")
    explanation = attr.ib(default=None)

    def infer_rules_statically(self) -> str:
        """Run the first, purely static pass of the rule inference process.

        The source and target examples are parsed, matched up by ``GraphParser`` and each
        matched pair is given to ``Inference`` to obtain one rule. The rules, together with
        the edges reported by the graph parser, are serialized as a rule graph.

        :return: str, string containing the inferred rule graph in TOML format
        """
        template_parser = TemplateParser(self.language)
        tree_before = template_parser.get_tree_from_code(self.source_code)
        tree_after = template_parser.get_tree_from_code(self.target_code)

        inferred = {}
        graph_parser = GraphParser(tree_before, tree_after)
        templates = graph_parser.parse_templates()
        for label, (before_nodes, after_nodes) in templates.items():
            engine = Inference(before_nodes, after_nodes, template_parser.template_holes)
            inferred[label] = engine.static_infer()

        # The parser's edges are keyed by template label; re-key them by the name of the
        # rule inferred for each label.
        targets_by_rule = {}
        for origin, destinations in graph_parser.edges.items():
            rule_name = inferred[origin].name
            targets_by_rule[rule_name] = [inferred[dest].name for dest in destinations]

        # Only rules that actually lead somewhere become edges of the graph.
        edge_entries = []
        for rule_name, target_names in targets_by_rule.items():
            if target_names:
                edge_entries.append({"from": rule_name, "to": target_names, "scope": "File"})

        return RawRuleGraph(list(inferred.values()), edge_entries).to_toml()

    def create_chats(self, rules: str) -> List[PiranhaGPTChat]:
        """
        Build the chat sessions used to talk to the AI model.

        Every chat is created from the same prompt data: the source code, the s-expressions
        of the comment-free source and target trees, a unified diff of the two examples, the
        statically inferred rules and the user hints. Only the first chat requests
        completions; each returned completion is then appended to one of the chats.

        :param rules: Statically inferred rules in TOML format
        :type rules: str
        :return: List of chat interactions with information necessary for AI model,
            or an empty list when the completion request fails.
        :rtype: List[PiranhaGPTChat]
        """
        template_parser = TemplateParser(self.language)
        src_tree = template_parser.get_tree_from_code(self.source_code, remove_comments=True)
        tgt_tree = template_parser.get_tree_from_code(self.target_code, remove_comments=True)

        src_sexpr = NodeUtils.generate_sexpr(src_tree.root_node, 0)
        tgt_sexpr = NodeUtils.generate_sexpr(tgt_tree.root_node, 0)

        # Line-based unified diff between the source and target examples.
        diff_text = "\n".join(
            difflib.unified_diff(
                self.source_code.splitlines(), self.target_code.splitlines()
            )
        )

        prompt_holes = {
            "source_code": self.source_code,
            "source_tree": src_sexpr,
            "target_tree": tgt_sexpr,
            "diff": diff_text,
            "rules": rules,
            "hints": self.hints,
        }

        # Number of chat interactions to have with the model
        sample_count = 5
        chats = []
        for _ in range(sample_count):
            chats.append(PiranhaGPTChat(holes=prompt_holes))

        try:
            replies = chats[0].get_completion(n_samples=sample_count)
        except PiranhaChatException as e:
            logger.debug(
                f"Chat completion failed with {e}. Trying again with a new chat...\n"
            )
            return []

        # The prompt is identical for every chat, so it is sent once and the sampled
        # replies are distributed over the chats instead of prompting each chat separately.
        for position, reply in enumerate(replies):
            chats[position].append_system_message(reply)
        return chats

    def get_explanation(self):
        return self.explanation

    def general_improvement(self, chat_interactions: List[PiranhaGPTChat]) -> str:
        """
        Query the chats until one of them yields a rule that passes validation.

        In every round each chat is asked for a response, which is checked with
        ``validate_rule``. A validation failure is sent back to that chat as a user
        follow-up; a chat failure is only logged. The first rule that validates is
        returned and its explanation is stored in ``self.explanation``.

        :param chat_interactions: List of chat sessions for the inference engine to use
        :type chat_interactions: List[PiranhaGPTChat]
        :return: The rule inferred from GPT
        :rtype: str
        :raises PiranhaAgentError: If the agent fails to generate a rule after 10 rounds of interaction with GPT-4.
        """
        max_rounds = 10
        for _ in range(max_rounds):
            for chat in chat_interactions:
                try:
                    completion = chat.get_model_response()
                    validated = self.validate_rule(completion)
                except PiranhaAgentError as e:
                    logger.debug(
                        f"GPT-4 failed to generate a rule. Following up the next round with {e}. Trying again...\n"
                    )
                    # Feed the failure back so the next response of this chat can address it.
                    chat.append_user_followup(str(e))
                except PiranhaChatException as e:
                    logger.debug(
                        f"Chat completion failed with {e}. Trying again with a new chat...\n"
                    )
                else:
                    # validate_rule returns (file name, TOML block, explanation).
                    self.explanation = validated[2]
                    return validated[1]
        raise PiranhaAgentError(
            f"Failed to generate a rule after {max_rounds} rounds of interaction with GPT-4.",
        )

    def validate_rule(self, completion) -> Tuple[str, str, str]:
        """
        Check that the rule in a model completion turns the source code into the target code.

        The completion must contain a fenced ``toml`` block (the rule) and a fenced ``md``
        block (the explanation). The first TOML block is parsed, run through
        ``run_piranha`` and the normalized result is compared with the normalized target code.

        :param completion: Inferred rule from the model
        :type completion: str
        :return: A tuple containing the file name, TOML block, and the explanation.
            The file name falls back to "rule.toml" when the completion does not name one.
        :rtype: Tuple[str, str, str]
        :raises PiranhaAgentError: If a block is missing, the TOML is invalid, Piranha
            produces no code, or the produced code differs from the target code.
        """
        logger.debug(f"Completion\n: {completion}")

        # Extract all toml block contents
        toml_blocks = re.findall(r"```toml(?!md)(.*?)```", completion, re.DOTALL)
        if len(toml_blocks) == 0:
            raise PiranhaAgentError(
                "No TOML block provided in the expected output format. "
                "Please provide a TOML block with the rule. ```toml ... ```"
            )

        explanation = re.findall(r"```md(.*?)```", completion, re.DOTALL)
        if len(explanation) == 0:
            raise PiranhaAgentError(
                "No explanation provided in the expected output format. "
                "Please provide an explanation as a markdown block. ```md ... ```"
            )

        try:
            # Only the first block is used; "parenthesized_expression" is rewritten to
            # "condition" before parsing (the reason is not documented in this file).
            first_block = toml_blocks[0]
            toml_block = first_block.replace("parenthesized_expression", "condition").strip()
            logger.debug(f"Generated rule: {toml_block}")
            toml_dict = toml.loads(toml_block)
        except Exception as e:
            raise PiranhaAgentError(
                f"Could not create Piranha rule. The TOML block is not valid: {e}. "
            )

        refactored_code = self.run_piranha(toml_dict)
        if not refactored_code:
            raise PiranhaAgentError(
                "Piranha did not generate any refactored code. Either the query or the filters are incorrect. "
            )

        normalized_result = NodeUtils.normalize_code(refactored_code)
        normalized_target = NodeUtils.normalize_code(self.target_code)
        if normalized_result != normalized_target:
            raise PiranhaAgentError(
                f"The rule produced wrong code!!! "
                f"Expected:\n{self.target_code}\n\n but got:\n{refactored_code}\n\n"
            )

        name_matches = re.findall(
            r"<file_name_start>(.*?)<file_name_end>", completion, re.DOTALL
        )
        if name_matches:
            file_name = name_matches[0]
        else:
            file_name = "rule.toml"
        return file_name, toml_block, explanation[0]

    def run_piranha(self, toml_dict) -> str:
        """
        Apply the inferred rule graph to the source code using Piranha.

        :param toml_dict: Inferred rules, already parsed from TOML
        :type toml_dict: dict
        :return: Refactored code as a result of the rule application
        :rtype: str
        :raises PiranhaAgentError: If the TOML has no rules, Piranha reports a failure,
            or the run times out.
        """
        if not toml_dict.get("rules", []):
            raise PiranhaAgentError("TOML does not include any rule specifications.")

        try:
            raw_graph = RawRuleGraph.from_toml(toml_dict)
            logger.debug(f"Raw graph: {raw_graph.to_toml()}")

            res, success = run_piranha_with_timeout(
                self.source_code, self.language, raw_graph, timeout=5
            )
        except multiprocessing.context.TimeoutError:
            raise PiranhaAgentError(
                "Piranha in infinite loop. Please add a filter or constraint the query. "
                "Remember you can only constraint queries with #eq, #not-eq, #match. "
                "Otherwise you need to use a [[rules.filters]] with contains or not_contains."
            )

        if success:
            return res

        # On failure ``res`` carries the error text; query errors get a more specific hint.
        if "QueryError" in res:
            raise PiranhaAgentError(
                f"One of the provided queries is not valid {res}. "
                f"Do not use nodes you cannot see in the tree representation. "
                f"Make sure you parenthesis are balanced."
            )
        raise PiranhaAgentError(f"Piranha failed to execute: {res}.")

    def improve_rules(self, rules: List[RawRule], task: str, chat: PiranhaGPTChat) -> Tuple[List[RawRule], List[str]]:
        """
        Run ``improve_single_rule`` on every rule, in order.

        :param rules: Rules to be processed
        :param task: Task for which rules are being processed
        :param chat: Chat interaction object
        :return: A tuple containing a list of updated rules and a list of explanations,
            both in the order of ``rules``
        """
        outcomes = [self.improve_single_rule(rule, task, chat) for rule in rules]
        updated_rules = [improved for improved, _ in outcomes]
        explanations = [reason for _, reason in outcomes]
        return updated_rules, explanations

    def improve_single_rule(self, rule: RawRule, task: str, chat: PiranhaGPTChat) -> Tuple[RawRule, str]:
        """
        Ask the controller whether a rule should be improved and, if so, how.

        The only improvement acted upon is "add filter", which delegates to ``add_filter``.
        In every other case the rule is returned untouched with an empty explanation.

        :param rule: Rule to be processed
        :param task: Task for which the rule is being processed
        :param chat: Chat interaction object
        :return: A tuple containing the updated rule and an explanation
        """
        rule_str = rule.to_toml()
        controller = Controller(chat)

        if not controller.should_improve_rule(task, rule_str):
            return rule, ''
        if controller.get_option_for_improvement(rule_str) != "add filter":
            return rule, ''
        return self.add_filter(task, rule, chat)

    def validate_improved_rules(self, updated_rules: List[RawRule], explanations: List[str]) -> str:
        """
        Validate the updated rules and record their explanations.

        The rules and explanations are packed into the same completion layout that
        ``validate_rule`` parses (file name markers, a toml block and an md block).
        After a successful validation the joined explanations are stored in
        ``self.explanation``.

        :param updated_rules: A list of updated rules
        :param explanations: A list of explanations for the rules
        :return: The TOML block returned by ``validate_rule``
        """
        rule_block = "\n".join(rule.to_toml() for rule in updated_rules)
        explanation_block = "\n".join(explanations)
        completion = (
            f"<file_name_start>rules.toml<file_name_end> ```toml\n{rule_block}\n``` ```md\n{explanation_block}\n```"
        )
        _, validated_toml, _ = self.validate_rule(completion)
        self.explanation = explanation_block
        return validated_toml

    def improve_rule_graph(self, task: str, rules: str, option: str) -> str:
        """
        Improve a rule graph, either through general model suggestions or rule by rule.

        With ``option == "general"`` the chats are handed to ``general_improvement``.
        For any other value the rules are parsed into a graph and, chat by chat, passed
        through ``improve_rules`` and ``validate_improved_rules``; the first result that
        validates is returned.

        :param task: Description of what you would like to do
        :param rules: Rules to improve, in TOML format
        :param option: "general" for general improvements; any other value improves rule by rule
        :return: The improved rule
        :raises PiranhaAgentError: If the chats cannot be created or the rule cannot be improved
        """

        try:
            chat_interactions = self.create_chats(rules)
        except Exception as e:
            # Deliberately Exception rather than BaseException: KeyboardInterrupt and
            # SystemExit must propagate instead of being reported as an agent error.
            logger.debug(f"Failed to create chat: {e}")
            raise PiranhaAgentError(str(e)) from e

        if option == "general":
            return self.general_improvement(chat_interactions)

        graph = RawRuleGraph.from_toml(toml.loads(rules))
        for chat in chat_interactions:
            try:
                updated_rules, explanations = self.improve_rules(graph.rules, task, chat)
                return self.validate_improved_rules(updated_rules, explanations)
            except Exception as e:
                logger.debug(
                    f"GPT-4 failed to generate a rule. Following up the next round with {e}. Trying again...\n"
                )
                chat.append_user_followup(str(e))

        logger.debug(f"Unable to improve rule {rules}.")
        raise PiranhaAgentError("Unable to improve rule.")

    def add_filter(self, desc: str, rule: RawRule, chat: PiranhaGPTChat) -> Tuple[RawRule, str]:
        """
        Adds a filter to the rule that encloses the nodes of the rule.

        The rule's query is run against the source code; every captured node and each of
        its ancestors is offered to the model as a possible ``enclosing_node``. The model's
        answer must contain a fenced ``toml`` block with the rule and a fenced ``md`` block
        with the explanation.

        :param desc: Description of what you would like to do
        :param rule: Rule to add a filter to
        :param chat: Chat interactions with information necessary for AI model
        :return: A tuple containing the rule with the added filter and its explanation
        :rtype: Tuple[RawRule, str]
        :raises PiranhaAgentError: If the model response has no TOML block, invalid TOML,
            no ``rules`` entry, or no explanation block.
        """

        # The agent keeps no parser state of its own, so parsing goes through
        # TemplateParser, as in the other methods of this class.
        # NOTE(review): assumes TemplateParser exposes `tree_sitter_language` — confirm.
        parser = TemplateParser(self.language)
        source_tree = parser.get_tree_from_code(self.source_code)
        tree_sitter_q = parser.tree_sitter_language.query(rule.query)
        captures = tree_sitter_q.captures(source_tree.root_node)
        captures = NodeUtils.get_smallest_nonoverlapping_set([c[0] for c in captures])

        # Every captured node and each of its ancestors is a candidate enclosing node.
        enclosing_nodes = []
        for node in captures:
            while node:
                enclosing_nodes.append(node)
                node = node.parent

        # Render one numbered option per candidate node.
        option_parts = []
        for i, node in enumerate(enclosing_nodes):
            node_query = QueryWriter([node]).write(simplify=True)
            option_parts.append(f"\n\n=== Option {i} ===\n\n")
            option_parts.append(f'enclosing_node = """{node_query}"""\n')
        enclosing_options = "".join(option_parts)

        chat.append_improve_request(
            desc,
            rule.to_toml(),
            enclosing_options,
        )
        completion = chat.get_model_response()

        # Extract all toml block contents. A malformed response is reported with a
        # descriptive error so callers that feed the error text back to the model can
        # send a useful follow-up.
        toml_blocks = re.findall(r"```toml(?!md)(.*?)```", completion, re.DOTALL)
        if not toml_blocks:
            raise PiranhaAgentError(
                "No TOML block provided in the expected output format. "
                "Please provide a TOML block with the rule. ```toml ... ```"
            )
        try:
            toml_dict = toml.loads(toml_blocks[0])
        except Exception as e:
            # Broad on purpose, matching validate_rule: any parser failure means the
            # block is unusable.
            raise PiranhaAgentError(
                f"Could not create Piranha rule. The TOML block is not valid: {e}. "
            ) from e
        if "rules" not in toml_dict:
            raise PiranhaAgentError("TOML does not include any rule specifications.")

        explanation = re.findall(r"```md(.*?)```", completion, re.DOTALL)
        if not explanation:
            raise PiranhaAgentError(
                "No explanation provided in the expected output format. "
                "Please provide an explanation as a markdown block. ```md ... ```"
            )

        return RawRule.from_toml(toml_dict["rules"]), explanation[0]
