# Copyright (c) 2023 Uber Technologies, Inc.
#
# <p>Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file
# except in compliance with the License. You may obtain a copy of the License at
# <p>http://www.apache.org/licenses/LICENSE-2.0
#
# <p>Unless required by applicable law or agreed to in writing, software distributed under the
# License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Dict, List

import attr
from comby import Comby
from tree_sitter import Node, TreeCursor

from piranha_playground.rule_inference.utils.node_utils import NodeUtils
from piranha_playground.rule_inference.utils.rule_utils import RawRule


@attr.s
class ChildProcessingStrategy(ABC):
    """
    Abstract base class for strategies that turn a named child node into a query fragment.

    :ivar query_writer: The QueryWriter the strategy reads from and writes into.
    """

    query_writer = attr.ib(type="QueryWriter")

    @abstractmethod
    def process_child(self, cursor: TreeCursor, depth: int) -> str:
        """Decide how a child of a node is handled when writing a query.

        An implementation may expand the query for the child, simply capture it
        without expanding it, etc.

        :param cursor: The cursor pointing to the child node.
        :param depth: The depth of the child's parent node in the tree.
        :return: The query fragment written for the child.
        """


@attr.s
class SimplifyStrategy(ChildProcessingStrategy):
    """
    Strategy that captures a named child node without expanding its children.
    """

    def process_child(self, cursor: TreeCursor, depth: int) -> str:
        """Capture the child node as a single pattern without expanding it.

        A child whose source text is a template hole is matched by an alternation of
        its allowed node types. Any other child is matched by its node type. Leaf
        children are also pinned to their exact text with an ``#eq?`` predicate.

        :param cursor: The cursor pointing to the child node.
        :param depth: The depth of the parent node; the child is indented one level deeper.
        :return: The query fragment for the child, prefixed with a newline.
        """
        writer = self.query_writer
        child_node: Node = cursor.node

        # Reserve a fresh capture index *before* naming the capture, as
        # QueryWriter.write_query does. The current value was already used for
        # the parent's tag, so using it here would give parent and child the same name.
        writer.count += 1
        tag = f"@tag{writer.count}n"
        indent = " " * (depth + 1)

        node_rep = child_node.text.decode("utf8")
        if node_rep in writer.template_holes:
            alternations = " ".join(
                f"({node_type})" for node_type in writer.template_holes[node_rep]
            )
            return f"\n{indent}[{alternations}] {tag}"

        if child_node.child_count == 0:
            # The text goes inside a double-quoted query string. Escape backslashes
            # first, then quotes, so a quote inside the token cannot end the string
            # early. Newlines are flattened to keep each predicate on one line.
            text = (
                node_rep.replace("\\", "\\\\").replace('"', '\\"').replace("\n", " ")
            )
            writer.query_ctrs.append(f'(#eq? {tag} "{text}")')
        return f"\n{indent}({child_node.type}) {tag}"


@attr.s
class RegularStrategy(ChildProcessingStrategy):
    """
    Strategy that expands a named child node into its own full query.
    """

    def process_child(self, cursor: TreeCursor, depth: int) -> str:
        """Write the child's query as if the child were the root of the tree.

        If the child occupies a named field, the field name is kept as a
        ``field: `` prefix on the child's pattern.

        :param cursor: The cursor pointing to the child node.
        :param depth: The depth of the parent node; the child is written one level deeper.
        :return: A newline followed by the child's query.
        """
        field_name = cursor.current_field_name()
        field_prefix = f"{field_name}: " if field_name else ""
        child_query = self.query_writer.write_query(
            cursor.node, depth + 1, field_prefix, simplify=False
        )
        return "\n" + child_query


@attr.s
class QueryWriter:
    """
    Class to write a tree-sitter query that matches a sequence of nodes.

    :ivar seq_nodes: Sequence of nodes for which the query will be written.
    :ivar template_holes: A dictionary mapping a node's source text to the node types allowed
        in its place. Such nodes are matched by type (or an alternation of types) and are not expanded.
    :ivar capture_groups: A dictionary mapping each capture name written by ``write_query``
        (e.g. ``@tag1n``) to the node it captures.
    :ivar count: A counter for naming the capture groups (``@tag<count>n``).
    :ivar query_str: The tree-sitter query produced by the last call to ``write``.
    :ivar query_ctrs: ``#eq?`` predicates collected while writing the query.
    :ivar outer_most_node: Capture name set by the most recently completed ``write_query`` call.
        After ``write`` this is the tag of the last node in ``seq_nodes``, i.e. the query's
        outermost node for a single-node sequence.
    :ivar strategy: The ChildProcessingStrategy chosen by ``write`` for named children.
    """

    seq_nodes = attr.ib(type=list)
    template_holes = attr.ib(type=Dict[str, List[str]])
    capture_groups = attr.ib(default=attr.Factory(dict))

    count = attr.ib(default=0)
    query_str = attr.ib(default="")
    query_ctrs = attr.ib(default=attr.Factory(list))
    outer_most_node = attr.ib(default=None)
    strategy = attr.ib(default=None)

    def write(self, simplify=False):
        """
        Write the query for ``seq_nodes`` and store it in ``query_str``.

        The per-node queries are joined with ``.``. The collected ``#eq?`` predicates
        are appended, and the whole query is wrapped in parentheses. ``query_str`` and
        ``query_ctrs`` are reset on every call. ``count`` and ``capture_groups`` are not.

        :param simplify: If True, named children are only captured, not expanded.
        :return: The query string.
        """
        self.strategy = SimplifyStrategy(self) if simplify else RegularStrategy(self)
        self.query_str = ""
        self.query_ctrs = []

        node_queries = [
            self.write_query(node, simplify=simplify) for node in self.seq_nodes
        ]

        body = ".".join(node_queries) + "\n" + "\n".join(self.query_ctrs)
        self.query_str = f"({body})"
        return self.query_str

    @staticmethod
    def _escape_predicate_text(text: str) -> str:
        """Escape ``text`` for use inside a double-quoted ``#eq?`` argument."""
        # Escape backslashes first so the backslashes added for quotes are not doubled.
        # Newlines are flattened to spaces so each predicate stays on one line.
        return text.replace("\\", "\\\\").replace('"', '\\"').replace("\n", " ")

    def write_query(self, node: Node, depth=0, prefix="", simplify=False):
        """
        Write a query for a given node, considering its depth and prefix.

        A node whose source text is a template hole is written as its allowed type(s),
        without expanding its children. Any other node is written as ``(type ...)``,
        with each named child handed to the current strategy. A node without named
        children is also pinned to its text with an ``#eq?`` predicate. Every node
        written here is captured as ``@tag<count>n`` and recorded in ``capture_groups``.

        :param node: The node for which the query will be written.
        :param depth: The current depth of the node; used for indentation.
        :param prefix: Prefix for the current node (e.g. a ``field: `` name).
        :param simplify: Only used to choose a strategy when ``write`` has not set one.
        :return: The query string for this node, followed by its capture name.
        """
        strategy = self.strategy
        if strategy is None:
            # Allow write_query to be used on its own, without a prior write() call.
            strategy = SimplifyStrategy(self) if simplify else RegularStrategy(self)

        indent = " " * depth
        self.count += 1
        node_repr = node.text.decode("utf8")
        node_name = f"@tag{self.count}n"

        if node_repr in self.template_holes:
            # Fixed node: match by allowed type(s) only.
            node_types = self.template_holes[node_repr]
            if len(node_types) == 1:
                s_exp = f"{indent}{prefix}({node_types[0]})"
            else:
                alternations = " ".join(f"({node_type})" for node_type in node_types)
                s_exp = f"{indent}{prefix}[{alternations}]"
        else:
            # Regular node: expand named children through the strategy.
            parts = [f"{indent}{prefix}({node.type} "]
            cursor: TreeCursor = node.walk()
            has_named_child = False
            has_child = cursor.goto_first_child()
            while has_child:
                if cursor.node.is_named:
                    has_named_child = True
                    parts.append(strategy.process_child(cursor, depth))
                has_child = cursor.goto_next_sibling()

            # A node without named children (e.g. an identifier) must match its exact text.
            if not has_named_child:
                text = self._escape_predicate_text(node_repr)
                self.query_ctrs.append(f'(#eq? {node_name} "{text}")')
            parts.append(")")
            s_exp = "".join(parts)

        self.capture_groups[node_name] = node
        self.outer_most_node = node_name
        return f"{s_exp} {node_name}"

    def simplify_query(self, capture_group):
        """
        Simplify a query by replacing all the children of capture_group with a wildcard node.
        This should be replaced with Piranha at some point.

        The ``#eq?`` predicates of the captured node and of every captured descendant
        are removed, and their entries are dropped from ``capture_groups``.

        :param capture_group: The capture group to simplify (e.g. ``@tag1n``).
        :raises KeyError: If ``capture_group`` is not a known capture group.
        """
        comby = Comby()
        match = f"(:[[node_name]] :[_]) {capture_group}"
        rewrite = f"(:[[node_name]]) {capture_group}"
        self.query_str = comby.rewrite(self.query_str, match, rewrite)

        # Remove equality checks for the node and all its descendants from the query.
        stack = [self.capture_groups[capture_group]]
        while stack:
            current = stack.pop()
            # Not every descendant has its own capture. For example, the children of a
            # template hole are never written, so the lookup can come back empty.
            to_remove = next(
                (key for key, value in self.capture_groups.items() if value == current),
                None,
            )
            if to_remove is not None:
                self.query_str = comby.rewrite(
                    self.query_str, f"(#eq? {to_remove} :[_])", ""
                )
                del self.capture_groups[to_remove]
            stack.extend(current.named_children)

    @staticmethod
    def _replace_outside_tags(segments, needle: str, tag: str):
        """Replace every ``needle`` in the non-tag segments with a ``tag`` segment."""
        result = []
        for text, is_tag in segments:
            if is_tag or needle not in text:
                result.append((text, is_tag))
                continue
            for index, piece in enumerate(text.split(needle)):
                if index:
                    result.append((tag, True))
                if piece:
                    result.append((piece, False))
        return result

    def replace_with_tags(self, replace_str: str) -> str:
        """
        Replace nodes with their corresponding capture group in the replace_str.

        Nodes are tried from the longest text to the shortest, so an enclosing node is
        substituted before the nodes inside it. A tag that has already been inserted is
        never searched again. This stops a short node text (e.g. ``n`` or ``1``) from
        matching inside a tag such as ``@tag1n``.

        :param replace_str: The string that needs replacement.
        :return: The replaced string.
        """
        # Each segment is (text, is_tag). Only non-tag segments are searched.
        segments = [(replace_str, False)]
        for capture_group, node in sorted(
            self.capture_groups.items(), key=lambda item: -len(item[1].text)
        ):
            # Skip entries removed since the snapshot was taken (e.g. by simplify_query).
            if capture_group not in self.capture_groups:
                continue
            text_repr = NodeUtils.convert_to_source(node)
            if not text_repr:
                # An empty needle would insert the tag between every character.
                continue
            segments = self._replace_outside_tags(segments, text_repr, capture_group)
        return "".join(text for text, _ in segments)


@attr.s
class Inference:
    """
    Class to infer a rewrite rule from the nodes before and after a change.

    :ivar nodes_before: The list of nodes before the transformation.
    :ivar nodes_after: The list of nodes after the transformation.
    :ivar template_holes: Maps source text to the node types allowed in its place;
        passed on to the QueryWriter.
    :ivar name: Name of the inference rule (``rule_<n>``), assigned after initialization.
    :ivar _counter: A class-wide counter for naming the inference rules."""

    nodes_before = attr.ib(type=List[Node], validator=attr.validators.instance_of(list))
    nodes_after = attr.ib(type=List[Node], validator=attr.validators.instance_of(list))
    template_holes = attr.ib(type=Dict[str, List[str]], default=attr.Factory(dict))
    name = attr.ib(
        init=False,
    )

    _counter = 0

    def __attrs_post_init__(self):
        """
        Initialization method to increment the counter and set the name for the rule.
        """
        type(self)._counter += 1
        self.name = f"rule_{type(self)._counter}"

    def static_infer(self) -> RawRule:
        """
        Infer a raw rule based on the nodes before and after.

        :return: A raw rule inferred from the nodes, or None when both lists are empty.
        :raises NotImplementedError: For pure additions (no nodes before), which
            ``create_addition`` does not support yet.
        """
        if self.nodes_before and self.nodes_after:
            return self.create_rule(self.nodes_before, self.nodes_after)
        if self.nodes_after:
            # Return, not raise: create_addition produces a rule once implemented.
            return self.create_addition()
        if self.nodes_before:
            return self.create_rule(self.nodes_before, [])
        return None

    @staticmethod
    def _single_diverging_child(node_before: Node, node_after: Node):
        """Return the only differing pair of named children, or None if the change can't be narrowed."""
        if (
            node_before.type != node_after.type
            or node_before.child_count != node_after.child_count
        ):
            return None

        diverging = []
        for child_before, child_after in zip(node_before.children, node_after.children):
            if child_before.is_named != child_after.is_named:
                return None
            if not child_before.is_named:
                # Anonymous children are tokens such as operators and keywords. If
                # one of them changed (e.g. `+` -> `-`), narrowing to a named child
                # would silently drop that edit.
                if child_before.text != child_after.text:
                    return None
                continue
            if NodeUtils.convert_to_source(child_before) != NodeUtils.convert_to_source(
                child_after
            ):
                diverging.append((child_before, child_after))

        return diverging[0] if len(diverging) == 1 else None

    def find_nodes_to_change(self, node_before: Node, node_after: Node):
        """
        Narrow a change down to the smallest pair of nodes that contains it.

        While both nodes have the same type and shape, their anonymous tokens are
        identical, and exactly one pair of named children differs, descend into that
        pair.

        :param node_before: The node before the change.
        :param node_after: The node after the change.
        :return: The nodes that need to be changed, as a ``(before, after)`` tuple.
        """
        while True:
            diverging = self._single_diverging_child(node_before, node_after)
            if diverging is None:
                return node_before, node_after
            node_before, node_after = diverging

    @staticmethod
    def _join_sources(nodes: List[Node]) -> str:
        """Join the source text of ``nodes`` with single spaces."""
        return " ".join(NodeUtils.convert_to_source(node) for node in nodes)

    def create_rule(self, nodes_before: List[Node], nodes_after: List[Node]) -> RawRule:
        """
        Create a rule based on the nodes before and after.

        The argument lists are not modified.

        :param nodes_before: The list of nodes before the change.
        :param nodes_after: The list of nodes after the change.
        :return: A raw rule representing the transformation from nodes_before to nodes_after.
        """
        if len(nodes_before) == 1:
            return self._create_single_node_rule(nodes_before[0], list(nodes_after))
        return self._create_multi_node_rule(nodes_before, nodes_after)

    def _create_single_node_rule(self, node_before: Node, nodes_after: List[Node]) -> RawRule:
        """Build a rule that rewrites one (possibly narrowed) node."""
        if len(nodes_after) == 1:
            node_before, node_after = self.find_nodes_to_change(node_before, nodes_after[0])
            nodes_after = [node_after]

        qw = QueryWriter([node_before], self.template_holes)
        query = qw.write()
        replacement_str = qw.replace_with_tags(self._join_sources(nodes_after))

        return RawRule(
            name=self.name,
            query=query,
            # Drop the leading "@" of the capture name.
            replace_node=qw.outer_most_node[1:],
            replace=replacement_str,
        )

    def _create_multi_node_rule(
        self, nodes_before: List[Node], nodes_after: List[Node]
    ) -> RawRule:
        """Build a rule that rewrites the lowest common ancestor of several nodes."""
        ancestor = NodeUtils.find_lowest_common_ancestor(nodes_before)
        replacement_str = NodeUtils.convert_to_source(ancestor, exclude=nodes_before)
        # Remove all but the last "{placeholder}" and put the new code in the remaining one.
        replacement_str = replacement_str.replace(
            "{placeholder}", "", len(nodes_before) - 1
        )
        replacement_str = replacement_str.replace(
            "{placeholder}", self._join_sources(nodes_after), 1
        )

        qw = QueryWriter([ancestor], self.template_holes)
        query = qw.write()
        replacement_str = qw.replace_with_tags(replacement_str)

        return RawRule(
            name=self.name,
            query=query,
            # Drop the leading "@" of the capture name.
            replace_node=qw.outer_most_node[1:],
            replace=replacement_str,
        )

    def create_addition(self) -> RawRule:
        """
        A method to create addition rules. Currently not implemented.

        :raise: NotImplementedError"""
        raise NotImplementedError
