tools/reachability-analysis/lib/core.py (231 lines of code) (raw):

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import array
import mmap
import os
import shutil
import struct
import subprocess
import tempfile


# Little-endian fixed-width integers used by the graph file format.
_U8 = struct.Struct("<B")
_U32 = struct.Struct("<L")


def _read_u8(mapping):
    """Read one unsigned byte from the current position of *mapping*."""
    return _U8.unpack(mapping.read(_U8.size))[0]


def _read_u32(mapping):
    """Read one little-endian unsigned 32-bit integer from *mapping*."""
    return _U32.unpack(mapping.read(_U32.size))[0]


def _read_string(mapping):
    """Read a u32 length prefix followed by that many bytes of name data.

    Bytes are decoded as UTF-8 (a superset of ASCII, so pure-ASCII names
    decode exactly as before); undecodable bytes are kept visible as
    backslash escapes instead of aborting the whole load.
    """
    size = _read_u32(mapping)
    return mapping.read(size).decode("utf-8", errors="backslashreplace")


class ReachableObjectType(object):
    """Integer tags for the kinds of node in a reachability graph.

    The values are the type bytes read by ``ReachabilityGraph.read_node``.
    """

    ANNO = 0
    CLASS = 1
    FIELD = 2
    METHOD = 3
    SEED = 4

    _NAMES = {
        ANNO: "ANNO",
        CLASS: "CLASS",
        FIELD: "FIELD",
        METHOD: "METHOD",
        SEED: "SEED",
    }

    @staticmethod
    def to_string(v):
        """Return the symbolic name of type tag *v*, or None if unknown."""
        return ReachableObjectType._NAMES.get(v)


# Aside from classes and annotations, the other nodes will never have collisions
# in their node names. Thus, we are able to infer their node type just by
# looking at their names. The functions below help with that.
def is_method(node_name):
    """Return True if *node_name* looks like a method (contains '(')."""
    return "(" in node_name


def is_field(node_name):
    """Return True if *node_name* looks like a field (':' but no '(')."""
    return ":" in node_name and not is_method(node_name)


def is_seed(node_name):
    """Return True if *node_name* is the synthetic seed node name."""
    return node_name == "<SEED>"


def show_list_with_idx(list):
    """Render each element of *list* as a "<index>: <item>" line.

    The parameter name shadows the builtin ``list``; it is kept so that
    existing keyword callers keep working.
    """
    return "".join("%d: %s\n" % (idx, item) for idx, item in enumerate(list))


def download_from_everstore(handle, filename):
    """Fetch *handle* into *filename* using the ``clowder`` CLI.

    Raises ``subprocess.CalledProcessError`` if the command fails.
    """
    subprocess.check_call(["clowder", "get", handle, filename])


class ReachableObject(object):
    """A node of the reachability graph.

    ``preds`` and ``succs`` are dicts used as ordered sets: keys are the
    neighbouring ``ReachableObject`` instances, and values are always None.
    """

    def __init__(self, type, name):
        self.type = type
        self.name = name
        self.preds = {}
        self.succs = {}

    def __str__(self):
        return "%s: %s\n" % (ReachableObjectType.to_string(self.type), self.name)

    def __repr__(self):
        ret = "%s: %s\n" % (ReachableObjectType.to_string(self.type), self.name)
        ret += "Reachable from %d predecessor(s):\n" % len(self.preds)
        ret += show_list_with_idx(list(self.preds.keys()))
        ret += "Reaching %d successor(s):\n" % len(self.succs)
        ret += show_list_with_idx(list(self.succs.keys()))
        return ret


class ReachableMethod(ReachableObject):
    """A method node augmented with override information.

    Wraps an existing ``ReachableObject`` *ro*. The wrapper shares *ro*'s
    ``preds``/``succs`` dicts (it does not copy them), so edges added through
    either object are visible from both. ``overriding`` and ``overriden_by``
    hold ``MethodOverrideGraph.Node`` objects taken from *mog*, if the method
    appears there.
    """

    # We need override info for a method.
    def __init__(self, ro, mog):
        # Deliberately not calling super().__init__: that would create fresh,
        # empty preds/succs dicts instead of sharing ro's.
        self.type = ro.type
        self.name = ro.name
        self.preds = ro.preds
        self.succs = ro.succs
        self.overriding = []
        self.overriden_by = []
        mog_node = mog.nodes.get(self.name)
        if mog_node is not None:
            self.overriding = mog_node.parents
            self.overriden_by = mog_node.children

    def __repr__(self):
        ret = super(ReachableMethod, self).__repr__()
        if self.overriding:
            ret += "Overriding %s methods:\n" % len(self.overriding)
            ret += show_list_with_idx([n.name for n in self.overriding])
        if self.overriden_by:
            ret += "Overriden by %s methods:\n" % len(self.overriden_by)
            ret += show_list_with_idx([n.name for n in self.overriden_by])
        return ret


class AbstractGraph(object):
    """
    This contains the deserialization counterpart to the graph serialization
    code in BinarySerialization.h.

    File layout as read by ``load`` (all integers little-endian u32):
    magic (0xFACEB000), version, node count, then per node the
    subclass-specific node payload (``read_node``) followed by an edge count
    and that many target node indices.
    """

    def __init__(self):
        self.nodes = {}

    def expected_version(self):
        raise NotImplementedError()

    def read_node(self, mapping):
        raise NotImplementedError()

    def add_node(self, node):
        raise NotImplementedError()

    def add_edge(self, n1, n2):
        raise NotImplementedError()

    def list_nodes(self, search_str=None):
        raise NotImplementedError()

    def read_header(self, mapping):
        """Validate the magic number and format version at the file start."""
        magic = _read_u32(mapping)
        if magic != 0xFACEB000:
            raise Exception("Magic number mismatch")
        version = _read_u32(mapping)
        if version != self.expected_version():
            raise Exception("Version mismatch")

    def load(self, fn):
        """Deserialize the graph stored in file *fn* into this object."""
        nodes = []
        out_edges = []
        # Binary mode plus ACCESS_READ works on every platform (PROT_READ is
        # Unix-only), and the `with` guarantees the mapping is released even
        # if parsing fails part-way through.
        with open(fn, "rb") as f, mmap.mmap(
            f.fileno(), 0, access=mmap.ACCESS_READ
        ) as mapping:
            self.read_header(mapping)
            nodes_count = _read_u32(mapping)
            for _ in range(nodes_count):
                node = self.read_node(mapping)
                nodes.append(node)
                self.add_node(node)
                edges_size = _read_u32(mapping)
                # NOTE(review): array("I") uses the host's native byte order
                # and item size; this assumes a little-endian host with a
                # 4-byte unsigned int.
                edges = array.array("I")
                edges.frombytes(mapping.read(4 * edges_size))
                out_edges.append(edges)
        # Edges are wired only after all nodes exist, because an edge may
        # point at a node that appears later in the file.
        for node, targets in zip(nodes, out_edges):
            for target in targets:
                self.add_edge(node, nodes[target])

    def __repr__(self):
        sorted_keys = sorted(self.nodes.keys())
        return "[" + ",\n".join([repr(self.nodes[k]) for k in sorted_keys]) + "]"


class ReachabilityGraph(AbstractGraph):
    """Reachability graph whose nodes are keyed by ``(type, name)`` tuples."""

    @staticmethod
    def expected_version():
        return 1

    def read_node(self, mapping):
        """Read a one-byte type tag followed by a length-prefixed name."""
        node_type = _read_u8(mapping)
        node_name = _read_string(mapping)
        return ReachableObject(node_type, node_name)

    def add_node(self, node):
        self.nodes[(node.type, node.name)] = node

    def list_nodes(self, search_str=None):
        """Print every node key whose name contains *search_str* (or all)."""
        for node_type, name in self.nodes.keys():
            if search_str is None or search_str in name:
                print(
                    '(ReachableObjectType.%s, "%s")'
                    % (ReachableObjectType.to_string(node_type), name)
                )

    @staticmethod
    def add_edge(n1, n2):
        """Record *n2* as a predecessor of *n1* (and *n1* as a successor of *n2*)."""
        # Edges are stored in dicts because lookup times are much faster with
        # dictionaries than with lists. The value isn't important; re-storing
        # None under an existing key keeps its original insertion position.
        n2.succs[n1] = None
        n1.preds[n2] = None

    def get_node(self, node_name):
        """Look up a node by name, inferring its type from the name.

        Raises KeyError if no node of the inferred type has that name.
        """
        if is_method(node_name):
            return self.nodes[(ReachableObjectType.METHOD, node_name)]
        if is_field(node_name):
            return self.nodes[(ReachableObjectType.FIELD, node_name)]
        # If we get here, we may have an annotation or a class. Just assume
        # we have a class. Users should call `get_anno` if they want to
        # retrieve an annotation.
        return self.nodes[(ReachableObjectType.CLASS, node_name)]

    def get_anno(self, node_name):
        return self.nodes[(ReachableObjectType.ANNO, node_name)]

    def get_seed(self, node_name):
        return self.nodes[(ReachableObjectType.SEED, node_name)]


class MethodOverrideGraph(AbstractGraph):
    """Method override graph whose nodes are keyed by method name."""

    class Node(object):
        """A method in the override graph.

        ``parents`` lists the methods it overrides; ``children`` lists the
        methods that override it.
        """

        def __init__(self, name):
            self.name = name
            self.parents = []
            self.children = []

    def __init__(self):
        super(MethodOverrideGraph, self).__init__()

    @staticmethod
    def expected_version():
        return 1

    def read_node(self, mapping):
        """Read a length-prefixed method name."""
        return self.Node(_read_string(mapping))

    def add_node(self, node):
        self.nodes[node.name] = node

    def list_nodes(self, search_str=None):
        """Print every method name containing *search_str* (or all)."""
        for key in self.nodes:
            if search_str is None or search_str in key:
                print('"%s"' % key)

    @staticmethod
    def add_edge(method, child):
        """Record that *child* overrides *method*."""
        method.children.append(child)
        child.parents.append(method)


class CombinedGraph(object):
    """Reachability graph enriched with method-override information.

    Every METHOD node is replaced by a ``ReachableMethod``. For each method
    that is overridden, the method-typed predecessors of the overridden method
    are also recorded as predecessors of each overriding child.
    """

    def __init__(self, reachability, method_override):
        self.reachability_graph = ReachabilityGraph()
        self.reachability_graph.load(reachability)
        self.method_override_graph = MethodOverrideGraph()
        self.method_override_graph.load(method_override)

        # Extract information from the override graph.
        # NOTE(review): other nodes' preds/succs dicts still use the original
        # ReachableObject instances as keys, not these ReachableMethod
        # wrappers, so walking edges yields objects without override info.
        reach_nodes = self.reachability_graph.nodes
        for key, reach_node in list(reach_nodes.items()):
            if key[0] == ReachableObjectType.METHOD:
                reach_nodes[key] = ReachableMethod(
                    reach_node, self.method_override_graph
                )

        # NOTE(review): only direct children are visited, and whether preds
        # added to a child here propagate on to its own children depends on
        # dict iteration order. Confirm whether transitive propagation is
        # intended.
        for method in self.method_override_graph.nodes:
            method_node = self._find_reachable(method)
            # Methods known to the override graph but absent from the
            # reachability graph (e.g. never reached, or defined outside the
            # analyzed code) have no node to update. Skip them rather than
            # failing with KeyError.
            if not isinstance(method_node, ReachableMethod):
                continue
            for child in method_node.overriden_by:
                # Find the child in the reachability graph, then build edges.
                method_child = self._find_reachable(child.name)
                if method_child is None:
                    continue
                # Snapshot the preds so adding edges can never mutate the dict
                # being iterated.
                for pred in list(method_node.preds):
                    if pred.type == ReachableObjectType.METHOD:
                        self.reachability_graph.add_edge(method_child, pred)

        self.nodes = self.reachability_graph.nodes

    def _find_reachable(self, name):
        """Return the reachability node for *name*, or None if it has none."""
        try:
            return self.reachability_graph.get_node(name)
        except KeyError:
            return None

    @staticmethod
    def from_everstore(reachability, method_override):
        """Download both graph files by handle and build a CombinedGraph.

        The temporary download directory is removed even when a download or
        the parse fails.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            r_tmp = os.path.join(temp_dir, "redex-reachability.graph")
            download_from_everstore(reachability, r_tmp)
            mog_tmp = os.path.join(temp_dir, "redex-method-override.graph")
            download_from_everstore(method_override, mog_tmp)
            # The graph is fully materialized in memory before the directory
            # is deleted.
            return CombinedGraph(r_tmp, mog_tmp)

    def node(self, search_str=None, search_type=None):
        """Find a node by substring match on its name.

        Returns the node when exactly one name matches, or when several match
        but exactly one equals *search_str*. Otherwise, prints the numbered
        candidates and returns a function mapping a printed index to its node.
        *search_type* restricts matches to one ``ReachableObjectType``, which
        is useful when a class and an annotation share a name.
        """
        known_names = [
            (node_type, name)
            for (node_type, name) in self.nodes
            if (search_type is None or node_type == search_type)
            and (search_str is None or search_str in name)
        ]

        node = None
        if search_str is not None:
            if len(known_names) == 1:
                # Know exactly one.
                node = self.nodes[known_names[0]]
            else:
                # Some names may contain the name of another node; in that
                # case prefer the single exact match.
                exact_match = [key for key in known_names if key[1] == search_str]
                if len(exact_match) == 1:
                    node = self.nodes[exact_match[0]]

        if node is not None:
            return node

        # We still can't tell which one the user wants: print all options.
        print("Found %s matching names:" % len(known_names))
        for idx, (node_type, name) in enumerate(known_names):
            print(
                '%d: (ReachableObjectType.%s, "%s")'
                % (idx, ReachableObjectType.to_string(node_type), name)
            )
        return lambda i: self.nodes[known_names[i]]