#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import array
import mmap
import os
import shutil
import struct
import subprocess
import tempfile


class ReachableObjectType(object):
    """Integer tags identifying the kind of a reachability-graph node."""

    ANNO = 0
    CLASS = 1
    FIELD = 2
    METHOD = 3
    SEED = 4

    @staticmethod
    def to_string(v):
        """Return the symbolic name of tag ``v``, or None when ``v`` is not a known tag."""
        labelled_tags = (
            (ReachableObjectType.ANNO, "ANNO"),
            (ReachableObjectType.CLASS, "CLASS"),
            (ReachableObjectType.FIELD, "FIELD"),
            (ReachableObjectType.METHOD, "METHOD"),
            (ReachableObjectType.SEED, "SEED"),
        )
        for tag, label in labelled_tags:
            if v == tag:
                return label
        return None


# Only classes and annotations can share a node name; no other node's name
# collides with any other node's. Thus, for all other nodes we can infer the
# node type just by looking at the name. The functions below help with that.


def is_method(node_name):
    """Return True if ``node_name`` names a method, i.e. contains a "(" signature."""
    return node_name.find("(") != -1


def is_field(node_name):
    """Return True if ``node_name`` names a field: it has a ":" but no "(" (which would make it a method)."""
    return "(" not in node_name and ":" in node_name


def is_seed(node_name):
    """Return True if ``node_name`` is the literal seed-node name "<SEED>"."""
    return "<SEED>" == node_name


def show_list_with_idx(list):
    """
    Format ``list`` as one "<index>: <item>" line per element.

    Each line ends with a newline; an empty input yields an empty string.
    Items are rendered with %s (i.e. their __str__).
    """
    return "".join("%d: %s\n" % (idx, item) for idx, item in enumerate(list))


def download_from_everstore(handle, filename):
    """
    Fetch ``handle`` (presumably an Everstore handle) into the local file ``filename``.

    Runs the external ``clowder get`` command with an argument list (no
    shell); subprocess.check_call raises CalledProcessError on a non-zero exit.
    """
    command = ["clowder", "get", handle, filename]
    subprocess.check_call(command)


class ReachableObject(object):
    """
    A node of the reachability graph.

    ``preds`` and ``succs`` are dicts used as ordered sets: the keys are the
    neighbouring nodes and the values are unused.
    """

    def __init__(self, type, name):
        self.type = type  # a ReachableObjectType tag
        self.name = name
        self.preds = {}
        self.succs = {}

    def __str__(self):
        label = ReachableObjectType.to_string(self.type)
        return "%s: %s\n" % (label, self.name)

    def __repr__(self):
        # Same header as __str__, followed by indexed neighbour listings.
        sections = [
            ReachableObject.__str__(self),
            "Reachable from %d predecessor(s):\n" % len(self.preds),
            show_list_with_idx(list(self.preds)),
            "Reaching %d successor(s):\n" % len(self.succs),
            show_list_with_idx(list(self.succs)),
        ]
        return "".join(sections)


class ReachableMethod(ReachableObject):
    """
    A METHOD node of the reachability graph, augmented with override info.

    Built from an existing ReachableObject ``ro``. It shares (does not copy)
    ro's preds/succs dicts. ``mog`` is a graph whose ``nodes`` dict maps a
    method name to a node with ``parents``/``children`` lists; methods absent
    from it get empty override lists.
    """

    def __init__(self, ro, mog):
        super(ReachableMethod, self).__init__(ro.type, ro.name)
        # Reuse the wrapped object's adjacency dicts rather than starting empty.
        self.preds = ro.preds
        self.succs = ro.succs

        if self.name in mog.nodes:
            mog_node = mog.nodes[self.name]
            self.overriding = mog_node.parents
            self.overriden_by = mog_node.children
        else:
            self.overriding = []
            self.overriden_by = []

    def __repr__(self):
        sections = [super(ReachableMethod, self).__repr__()]
        if self.overriding:
            sections.append("Overriding %s methods:\n" % len(self.overriding))
            sections.append(show_list_with_idx([m.name for m in self.overriding]))
        if self.overriden_by:
            sections.append("Overriden by %s methods:\n" % len(self.overriden_by))
            sections.append(show_list_with_idx([m.name for m in self.overriden_by]))
        return "".join(sections)


class AbstractGraph(object):
    """
    This contains the deserialization counterpart to the graph serialization
    code in BinarySerialization.h.

    File layout as consumed by load() (integers little-endian):
      u32 magic (0xFACEB000), u32 version, u32 node count, then for each node:
      a subclass-defined node record (see read_node), a u32 edge count, and
      that many u32 indices of target nodes.

    Subclasses implement expected_version, read_node, add_node, add_edge and
    list_nodes. Parsed nodes are kept in ``self.nodes`` by the subclass.
    """

    def __init__(self):
        self.nodes = {}

    def expected_version(self):
        """Return the format version this graph type can read."""
        raise NotImplementedError()

    def read_node(self, mapping):
        """Read one node record from ``mapping`` and return the node object."""
        raise NotImplementedError()

    def add_node(self, node):
        """Register ``node`` (as returned by read_node) in this graph."""
        raise NotImplementedError()

    def add_edge(self, n1, n2):
        """Record the serialized edge from node ``n1`` to node ``n2``."""
        raise NotImplementedError()

    def list_nodes(self, search_str=None):
        """Print the nodes, optionally filtered by ``search_str``."""
        raise NotImplementedError()

    def read_header(self, mapping):
        """Validate the magic number and format version; raise Exception on mismatch."""
        magic = struct.unpack("<L", mapping.read(4))[0]
        if magic != 0xFACEB000:
            raise Exception("Magic number mismatch")
        version = struct.unpack("<L", mapping.read(4))[0]
        if version != self.expected_version():
            raise Exception("Version mismatch")

    def load(self, fn):
        """
        Deserialize the graph stored in file ``fn`` into this object.

        Nodes are added in file order via add_node(). Edges are added via
        add_edge(source, target) only once every node has been read, because
        an edge may point at a node that appears later in the file.
        """
        import sys

        # Binary mode plus a read-only, portable mapping; both the mapping
        # and the file are closed even if parsing raises.
        with open(fn, "rb") as f:
            with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mapping:
                self.read_header(mapping)
                nodes_count = struct.unpack("<L", mapping.read(4))[0]
                nodes = [None] * nodes_count
                out_edges = [None] * nodes_count
                for i in range(nodes_count):
                    node = self.read_node(mapping)
                    nodes[i] = node
                    self.add_node(node)

                    edges_size = struct.unpack("<L", mapping.read(4))[0]
                    edges = array.array("I")
                    edges.frombytes(mapping.read(4 * edges_size))
                    # The file stores little-endian indices; array uses
                    # native byte order.
                    if sys.byteorder == "big":
                        edges.byteswap()
                    out_edges[i] = edges

        # Node objects no longer depend on the (now closed) mapping.
        for node, edges in zip(nodes, out_edges):
            for target in edges:
                self.add_edge(node, nodes[target])

    def __repr__(self):
        sorted_keys = sorted(self.nodes.keys())
        return "[" + ",\n".join([self.nodes[k].__repr__() for k in sorted_keys]) + "]"


class ReachabilityGraph(AbstractGraph):
    """
    Reachability graph (format version 1) of ReachableObject nodes.

    Nodes are keyed by (type tag, name). An edge read for node A towards
    node B is stored as B in A.preds and A in B.succs, i.e. A is reachable
    from B.
    """

    @staticmethod
    def expected_version():
        return 1

    def read_node(self, mapping):
        """Read one node record: u8 type tag, u32 name length, ASCII name bytes."""
        (kind,) = struct.unpack("<B", mapping.read(1))
        (length,) = struct.unpack("<L", mapping.read(4))
        label = mapping.read(length).decode("ascii")
        return ReachableObject(kind, label)

    def add_node(self, node):
        key = (node.type, node.name)
        self.nodes[key] = node

    def list_nodes(self, search_str=None):
        """Print each node key as `(ReachableObjectType.X, "name")`, optionally filtered by name substring."""
        for kind, name in self.nodes:
            if search_str is not None and search_str not in name:
                continue
            label = ReachableObjectType.to_string(kind)
            print('(ReachableObjectType.%s, "%s")' % (label, name))

    @staticmethod
    def add_edge(n1, n2):
        """Record that ``n1`` is reachable from ``n2``; repeated calls are no-ops."""
        # Dicts act as ordered sets with fast membership; values are unused.
        n2.succs.setdefault(n1, None)
        n1.preds.setdefault(n2, None)

    def get_node(self, node_name):
        """
        Look up a node by name, inferring its type from the name.

        Names that are neither methods nor fields are looked up as classes;
        use get_anno for annotations. Raises KeyError if absent.
        """
        if is_method(node_name):
            kind = ReachableObjectType.METHOD
        elif is_field(node_name):
            kind = ReachableObjectType.FIELD
        else:
            kind = ReachableObjectType.CLASS
        return self.nodes[(kind, node_name)]

    def get_anno(self, node_name):
        """Look up an annotation node by name; raises KeyError if absent."""
        return self.nodes[(ReachableObjectType.ANNO, node_name)]

    def get_seed(self, node_name):
        """Look up a seed node by name; raises KeyError if absent."""
        return self.nodes[(ReachableObjectType.SEED, node_name)]


class MethodOverrideGraph(AbstractGraph):
    """
    Method override graph (format version 1), with nodes keyed by method name.

    An edge read for node M towards node C appends C to M.children and M to
    C.parents. ReachableMethod interprets children as the methods that
    override M.
    """

    class Node(object):
        """A single method in the override graph."""

        def __init__(self, name):
            self.name = name
            self.parents = []  # nodes that list this one among their children
            self.children = []

    def __init__(self):
        super(MethodOverrideGraph, self).__init__()

    @staticmethod
    def expected_version():
        return 1

    def read_node(self, mapping):
        """Read one node record: u32 name length followed by ASCII name bytes."""
        (length,) = struct.unpack("<L", mapping.read(4))
        raw_name = mapping.read(length)
        return self.Node(raw_name.decode("ascii"))

    def add_node(self, node):
        self.nodes[node.name] = node

    def list_nodes(self, search_str=None):
        """Print each method name in quotes, optionally filtered by substring."""
        matching = (
            name for name in self.nodes if search_str is None or search_str in name
        )
        for name in matching:
            print('"%s"' % name)

    @staticmethod
    def add_edge(method, child):
        """Link ``child`` under ``method`` in both directions."""
        child.parents.append(method)
        method.children.append(child)


class CombinedGraph(object):
    """
    A reachability graph merged with a method override graph.

    Every METHOD node of the reachability graph is replaced by a
    ReachableMethod carrying its override relations. Each overriding method
    is also made reachable from every METHOD predecessor of the method it
    overrides. ``self.nodes`` aliases the reachability graph's node dict.
    """

    def __init__(self, reachability, method_override):
        """
        Load and merge both graphs.

        :param reachability: path to a serialized reachability graph.
        :param method_override: path to a serialized method override graph.
        """
        self.reachability_graph = ReachabilityGraph()
        self.reachability_graph.load(reachability)
        self.method_override_graph = MethodOverrideGraph()
        self.method_override_graph.load(method_override)

        # extract information from the override graph
        self._wrap_methods()
        self._add_override_edges()

        self.nodes = self.reachability_graph.nodes

    def _wrap_methods(self):
        """Replace each METHOD node with a ReachableMethod and re-point all edges at it."""
        nodes = self.reachability_graph.nodes
        wrapped = {}
        for key, obj in list(nodes.items()):
            if key[0] == ReachableObjectType.METHOD:
                method = ReachableMethod(obj, self.method_override_graph)
                wrapped[obj] = method
                nodes[key] = method

        if not wrapped:
            return
        # Neighbour dicts are keyed by node objects (identity hash) and still
        # refer to the original ReachableObjects. Re-key them so walking an edge
        # yields the ReachableMethod, and so add_edge() cannot store the same
        # method twice under two different objects.
        for node in nodes.values():
            node.preds = dict.fromkeys(wrapped.get(p, p) for p in node.preds)
            node.succs = dict.fromkeys(wrapped.get(s, s) for s in node.succs)

    def _add_override_edges(self):
        """Make each overriding method reachable from the METHOD predecessors of the method it overrides."""
        graph = self.reachability_graph
        for method_name in self.method_override_graph.nodes:
            method_node = graph.nodes.get((ReachableObjectType.METHOD, method_name))
            if method_node is None:
                # Not present in the reachability graph; skip instead of
                # raising KeyError.
                continue
            for child in method_node.overriden_by:
                # find child in reachability graph, then build edge
                child_node = graph.nodes.get((ReachableObjectType.METHOD, child.name))
                if child_node is None:
                    continue
                # Snapshot, since add_edge mutates preds dicts.
                for pred in list(method_node.preds):
                    if pred.type == ReachableObjectType.METHOD:
                        graph.add_edge(child_node, pred)

    @staticmethod
    def from_everstore(reachability, method_override):
        """
        Download both graphs into a temporary directory and build a CombinedGraph.

        The directory is removed even when a download or the load fails.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            r_tmp = os.path.join(temp_dir, "redex-reachability.graph")
            download_from_everstore(reachability, r_tmp)
            mog_tmp = os.path.join(temp_dir, "redex-method-override.graph")
            download_from_everstore(method_override, mog_tmp)
            return CombinedGraph(r_tmp, mog_tmp)

    def node(self, search_str=None, search_type=None):
        """
        Look up a node by (partial) name.

        Returns the node when ``search_str`` selects exactly one candidate:
        either it is the only name containing ``search_str``, or the only
        exact match among several. Otherwise, prints the indexed candidates
        and returns a function mapping an index to the corresponding node.

        :param search_str: substring to match in node names; None matches all
            (and always yields the candidate listing).
        :param search_type: optional ReachableObjectType tag. Classes and
            annotations may share names; use this to filter between them.
        """
        known_names = [
            (node_type, name)
            for (node_type, name) in self.nodes
            if (search_type is None or node_type == search_type)
            and (search_str is None or search_str in name)
        ]

        node = None
        if search_str is not None:
            if len(known_names) == 1:
                # know exactly one
                node = self.nodes[known_names[0]]
            else:
                # A name may contain another node's name; prefer the unique
                # exact match in that case.
                exact_match = [key for key in known_names if key[1] == search_str]
                if len(exact_match) == 1:
                    node = self.nodes[exact_match[0]]

        # if after all we still can't tell which one the user wants,
        # print all options
        if node is None:
            print("Found %s matching names:" % len(known_names))
            for idx, (node_type, name) in enumerate(known_names):
                print(
                    '%d: (ReachableObjectType.%s, "%s")'
                    % (idx, ReachableObjectType.to_string(node_type), name)
                )
            return lambda i: self.nodes[known_names[i]]

        return node
