# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from functools import lru_cache
from itertools import islice
from typing import Any, Dict, Generator, Iterable, List, NamedTuple, Optional, TypeVar

from .connection import PyreConnection, PyreQueryResult


# Generic element type used in the signature of the `_get_batch` helper below.
T = TypeVar("T")


class Attributes(NamedTuple):
    """One class attribute as reported by a Pyre `attributes(...)` query.

    Each field is copied from the same-named key of a response entry.
    """

    name: str
    # The response may carry no annotation, hence Optional.
    annotation: Optional[str]
    kind: str
    final: bool


class DefineParameter(NamedTuple):
    """A single parameter of a `Define`, as returned by a `defines(...)` query."""

    name: str
    annotation: str


class Define(NamedTuple):
    """A definition returned by a Pyre `defines(...)` query.

    `name` is a dotted name; the helpers below split it at its last dot.
    """

    name: str
    parameters: List[DefineParameter]
    return_annotation: str

    def get_class_name(self) -> str:
        """Return the part of `name` before its last dot ("" if there is none)."""
        prefix, _, _ = self.name.rpartition(".")
        return prefix

    def get_method_name(self) -> str:
        """Return the part of `name` after its last dot (all of it if no dot)."""
        _, _, last_component = self.name.rpartition(".")
        return last_component


class Position(NamedTuple):
    """A (line, column) pair taken from a Pyre query response."""

    line: int
    column: int


class Location(NamedTuple):
    """A source range: a file path plus start and stop positions."""

    path: str
    start: Position
    stop: Position


@dataclass(frozen=True)
class Annotation:
    """An annotation string together with the source range it covers.

    Built by `_annotations_per_file` from the entries of a `types(...)`
    query response. Frozen, so instances are immutable.
    """

    # The "annotation" value of the response entry.
    type_name: str
    # Start/stop of the entry's "location".
    start: Position
    stop: Position


class CallGraphTarget:
    """A single callee entry from a Pyre `dump_call_graph()` response.

    Attributes:
        target: The callee name, read from the entry's "target" key or, when
            that key is absent, from "direct_target".
        kind: The entry's "kind" value.
        locations: The parsed entries of the "locations" list.

    Instances compare by value. Because `__eq__` is defined without
    `__hash__`, instances are unhashable.
    """

    def __init__(self, call: Dict[str, Any]) -> None:
        """Build a target from one raw call-graph entry.

        Raises:
            KeyError: If a required key is missing from `call`.
        """
        # Entries carry the callee under either "target" or "direct_target".
        self.target: str = (
            call["target"] if "target" in call else call["direct_target"]
        )
        self.kind: str = call["kind"]
        self.locations: List[Location] = [
            _parse_location(location) for location in call["locations"]
        ]

    def __eq__(self, other: object) -> bool:
        # Defer to the other operand for foreign types instead of failing with
        # AttributeError; `target == "some string"` then evaluates to False.
        if not isinstance(other, CallGraphTarget):
            return NotImplemented
        return (
            self.target == other.target
            and self.kind == other.kind
            and self.locations == other.locations
        )

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(target={self.target!r}, "
            f"kind={self.kind!r}, locations={self.locations!r})"
        )


class ClassHierarchy:
    """Class hierarchy lookups in both directions.

    `hierarchy` maps a class name to the list of its superclasses. The inverse
    mapping (class name -> subclasses) is derived lazily on first use and
    cached on the instance.
    """

    def __init__(self, hierarchy: Dict[str, List[str]]) -> None:
        self.hierarchy = hierarchy
        # Per-instance cache. A shared `lru_cache(maxsize=1)` on the method
        # would be evicted whenever a second instance is queried and would
        # keep the most recently used instance alive.
        self._reverse_hierarchy: Optional[Dict[str, List[str]]] = None

    @property
    def reverse_hierarchy(self) -> Dict[str, List[str]]:
        """Mapping from each class name to the list of its subclasses.

        Computed once per instance; later changes to `hierarchy` are not
        reflected.
        """
        cached = self._reverse_hierarchy
        if cached is not None:
            return cached

        # In order to distinguish between missing types and types
        # with no subclasses, we initialize everything to [] for known keys.
        reversed_mapping: Dict[str, List[str]] = {key: [] for key in self.hierarchy}
        for subclass, superclasses in self.hierarchy.items():
            for superclass in superclasses:
                # `setdefault` tolerates a superclass that has no entry of its
                # own in `hierarchy` instead of raising KeyError.
                reversed_mapping.setdefault(superclass, []).append(subclass)

        self._reverse_hierarchy = reversed_mapping
        return reversed_mapping

    def subclasses(self, class_name: str) -> List[str]:
        """Return the subclasses of `class_name`, or [] if it is unknown."""
        return self.reverse_hierarchy.get(class_name, [])

    def superclasses(self, class_name: str) -> List[str]:
        """Return the superclasses of `class_name`, or [] if it is unknown."""
        return self.hierarchy.get(class_name, [])


@dataclass
class PyreCache:
    """Mutable holder for previously fetched query results.

    `get_cached_class_hierarchy` reads and fills `class_hierarchy`.
    """

    # None until a class hierarchy has been fetched and stored.
    class_hierarchy: Optional[ClassHierarchy] = None


class InvalidModel(NamedTuple):
    """One error entry from a `validate_taint_models()` query response."""

    # `get_invalid_taint_models` always sets this to "".
    fully_qualified_name: str
    path: Optional[str]
    # Start of the reported range.
    line: int
    column: int
    # End of the reported range.
    stop_line: int
    stop_column: int
    # The error's "description" value.
    full_error_message: str


def _defines(pyre_connection: PyreConnection, modules: Iterable[str]) -> List[Define]:
    """Run one `defines(...)` query for `modules` and parse the response.

    Module names are comma-joined into a single query; every element of the
    response becomes a `Define` with its parameters converted to
    `DefineParameter`.
    """
    response = pyre_connection.query_server(
        "defines({})".format(",".join(modules))
    )["response"]

    found: List[Define] = []
    for entry in response:
        define_name = entry["name"]
        define_parameters = [
            DefineParameter(name=raw["name"], annotation=raw["annotation"])
            for raw in entry["parameters"]
        ]
        found.append(
            Define(
                name=define_name,
                parameters=define_parameters,
                return_annotation=entry["return_annotation"],
            )
        )
    return found


def defines(
    pyre_connection: PyreConnection,
    modules: Iterable[str],
    batch_size: Optional[int] = None,
) -> List[Define]:
    """Return the definitions found in `modules`.

    Args:
        pyre_connection: Connection used to query the Pyre server.
        modules: Names of the modules to query.
        batch_size: When None, all modules are sent in one query. Otherwise
            the modules are queried in consecutive chunks of at most
            `batch_size` names and the results are concatenated in order.

    Raises:
        ValueError: If `batch_size` is given and is not positive.
    """
    # Validate before consuming `modules` so bad input fails fast.
    if batch_size is not None and batch_size <= 0:
        raise ValueError(
            "batch_size must be a positive integer, provided: `{}`".format(batch_size)
        )

    # Materialize once: the input may be a single-pass iterable and chunking
    # needs slicing.
    module_list = list(modules)
    if batch_size is None:
        return _defines(pyre_connection, module_list)

    found_defines: List[Define] = []
    for start in range(0, len(module_list), batch_size):
        chunk = module_list[start : start + batch_size]
        found_defines.extend(_defines(pyre_connection, chunk))
    return found_defines


def get_class_hierarchy(pyre_connection: PyreConnection) -> ClassHierarchy:
    """Query `dump_class_hierarchy()` and wrap the result in a ClassHierarchy.

    The response is a sequence of mappings; they are flattened into a single
    dict, with later entries overriding earlier ones on key collisions.
    """
    response = pyre_connection.query_server("dump_class_hierarchy()")["response"]

    flattened = {}
    for annotation_and_edges in response:
        for class_name, edges in annotation_and_edges.items():
            flattened[class_name] = edges
    return ClassHierarchy(flattened)


def get_cached_class_hierarchy(
    pyre_connection: PyreConnection, pyre_cache: Optional[PyreCache]
) -> ClassHierarchy:
    """Return the class hierarchy, reusing `pyre_cache` when one is supplied.

    Without a cache the server is always queried. With a cache, the stored
    hierarchy is returned if present; otherwise it is fetched, stored on the
    cache, and returned.
    """
    if pyre_cache is None:
        return get_class_hierarchy(pyre_connection)

    hierarchy = pyre_cache.class_hierarchy
    if hierarchy is None:
        hierarchy = get_class_hierarchy(pyre_connection)
        pyre_cache.class_hierarchy = hierarchy
    return hierarchy


def _annotations_per_file(data: PyreQueryResult) -> Dict[str, List[Annotation]]:
    """Group the annotations of a batched `types(...)` response by file path.

    Items of `data["response"]` without a "response" key are skipped. For the
    rest, the first element of the nested response supplies the path and the
    list of typed locations.
    """
    annotations_by_path: Dict[str, List[Annotation]] = {}
    for item in data["response"]:
        if "response" not in item:
            continue

        file_types = item["response"][0]
        file_annotations = []
        for typed_location in file_types["types"]:
            type_name = typed_location["annotation"]
            start = typed_location["location"]["start"]
            start_position = Position(column=start["column"], line=start["line"])
            stop = typed_location["location"]["stop"]
            stop_position = Position(column=stop["column"], line=stop["line"])
            file_annotations.append(
                Annotation(type_name, start_position, stop_position)
            )
        annotations_by_path[file_types["path"]] = file_annotations
    return annotations_by_path


def get_types(
    pyre_connection: PyreConnection, *paths: str
) -> Dict[str, List[Annotation]]:
    """Query the types of every given path in one batch request.

    Returns a mapping from file path to the annotations found in that file.
    """
    per_path_queries = (f"types('{path}')" for path in paths)
    joined_queries = ",".join(per_path_queries)
    return _annotations_per_file(
        pyre_connection.query_server(f"batch({joined_queries})")
    )


def get_superclasses(pyre_connection: PyreConnection, class_name: str) -> List[str]:
    """Return the superclasses of `class_name` via a `superclasses(...)` query.

    The value is read from the first response element, keyed by `class_name`.
    """
    response = pyre_connection.query_server(f"superclasses({class_name})")[
        "response"
    ]
    first_entry = response[0]
    return first_entry[class_name]


def _get_batch(
    iterable: Iterable[T], batch_size: Optional[int]
) -> Generator[Iterable[T], None, None]:
    if not batch_size:
        yield iterable
    elif batch_size <= 0:
        raise ValueError(
            "batch_size must a positive integer, provided: `{}`".format(batch_size)
        )
    else:
        iterator = iter(iterable)
        batch = list(islice(iterator, batch_size))
        while batch:
            yield batch
            batch = list(islice(iterator, batch_size))


def _get_attributes(
    pyre_connection: PyreConnection, class_name: str
) -> List[Attributes]:
    """Fetch the attributes of one class with an `attributes(...)` query."""
    response = pyre_connection.query_server(f"attributes({class_name})")["response"]

    parsed: List[Attributes] = []
    for raw in response["attributes"]:
        parsed.append(
            Attributes(
                name=raw["name"],
                annotation=raw["annotation"],
                kind=raw["kind"],
                final=raw["final"],
            )
        )
    return parsed


def get_attributes(
    pyre_connection: PyreConnection,
    class_names: Iterable[str],
    batch_size: Optional[int] = None,
) -> Dict[str, List[Attributes]]:
    """Fetch the attributes of several classes using batched queries.

    Args:
        pyre_connection: Connection used to query the Pyre server.
        class_names: Names of the classes to look up.
        batch_size: Maximum number of classes per `batch(...)` query; None
            sends all of them in a single query.

    Returns:
        A mapping from class name to its attributes. Responses are paired
        with class names by position.
    """
    all_responses: Dict[str, List[Attributes]] = {}
    for batch in _get_batch(class_names, batch_size):
        # The batch is walked twice (query construction, then pairing with the
        # responses). Without batching it is the caller's iterable, which may
        # be single-pass, so materialize it first.
        names = list(batch)
        if not names:
            # Nothing to ask for; avoid sending an empty `batch()` query.
            continue
        query = "batch({})".format(", ".join([f"attributes({name})" for name in names]))
        responses = pyre_connection.query_server(query)["response"]
        for class_name, response in zip(names, responses):
            all_responses[class_name] = [
                Attributes(
                    name=attribute["name"],
                    annotation=attribute["annotation"],
                    kind=attribute["kind"],
                    final=attribute["final"],
                )
                for attribute in response["response"]["attributes"]
            ]
    return all_responses


def get_call_graph(
    pyre_connection: PyreConnection,
) -> Optional[Dict[str, List[CallGraphTarget]]]:
    """Query `dump_call_graph()` and parse it into CallGraphTarget lists.

    Returns a mapping from each function in the response to its parsed calls.
    """
    raw_graph = pyre_connection.query_server("dump_call_graph()")["response"]
    return {
        caller: [CallGraphTarget(raw_call) for raw_call in raw_calls]
        for caller, raw_calls in raw_graph.items()
    }


def _parse_location(location_json: Dict[str, Any]) -> Location:
    """Convert a raw location mapping ("path", "start", "stop") to a Location."""
    path = location_json["path"]
    start = _parse_position(location_json["start"])
    stop = _parse_position(location_json["stop"])
    return Location(path, start, stop)


def _parse_position(position_json: Dict[str, Any]) -> Position:
    """Convert a raw position mapping ("line", "column") to a Position."""
    line = position_json["line"]
    column = position_json["column"]
    return Position(line, column)


def get_invalid_taint_models(
    pyre_connection: PyreConnection,
) -> List[InvalidModel]:
    """Run `validate_taint_models()` and return the reported errors.

    Returns an empty list when the result has no "response" key or the
    response has no "errors" key.
    """
    result = pyre_connection.query_server("validate_taint_models()")
    if "response" not in result or "errors" not in result["response"]:
        return []

    return [
        InvalidModel(
            full_error_message=raw_error["description"],
            path=raw_error["path"],
            line=raw_error["line"],
            column=raw_error["column"],
            stop_line=raw_error["stop_line"],
            stop_column=raw_error["stop_column"],
            # The query response carries no qualified name for the model.
            fully_qualified_name="",
        )
        for raw_error in result["response"]["errors"]
    ]
