api/query.py (250 lines of code) (raw):

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Helpers that send queries to a Pyre server through a ``PyreConnection`` and
parse the JSON results into typed Python objects."""

from dataclasses import dataclass
from functools import lru_cache
from itertools import islice
from typing import Any, Dict, Generator, Iterable, List, NamedTuple, Optional, TypeVar

from .connection import PyreConnection, PyreQueryResult

T = TypeVar("T")


class Attributes(NamedTuple):
    """One class attribute as reported by Pyre's ``attributes(...)`` query."""

    name: str
    annotation: Optional[str]
    kind: str
    final: bool


class DefineParameter(NamedTuple):
    """One parameter of a define reported by Pyre's ``defines(...)`` query."""

    name: str
    annotation: str


class Define(NamedTuple):
    """A define (function or method) reported by Pyre's ``defines(...)`` query."""

    # Dotted name; the last component is the function/method name.
    name: str
    parameters: List[DefineParameter]
    return_annotation: str

    def get_class_name(self) -> str:
        """Return ``name`` without its last dotted component ("" if undotted)."""
        return self.name.rpartition(".")[0]

    def get_method_name(self) -> str:
        """Return the last dotted component of ``name``."""
        return self.name.rpartition(".")[2]


class Position(NamedTuple):
    """A (line, column) position as reported by Pyre."""

    line: int
    column: int


class Location(NamedTuple):
    """A source range inside the file at ``path``."""

    path: str
    start: Position
    stop: Position


@dataclass(frozen=True)
class Annotation:
    """The type Pyre inferred for the source range ``start``..``stop``."""

    type_name: str
    start: Position
    stop: Position


class CallGraphTarget:
    """One callee entry of a caller in Pyre's ``dump_call_graph()`` output."""

    def __init__(self, call: Dict[str, Any]) -> None:
        # Entries carry either "target" or "direct_target"; "target" wins when
        # both are present, and a KeyError is raised when neither is.
        self.target: str = (
            call["target"] if "target" in call else call["direct_target"]
        )
        self.kind: str = call["kind"]
        self.locations: List[Location] = [
            _parse_location(location) for location in call["locations"]
        ]

    def __eq__(self, other: object) -> bool:
        # Defer to Python's default handling for foreign types instead of
        # crashing with AttributeError on a missing ``target``.
        if not isinstance(other, CallGraphTarget):
            return NotImplemented
        return (
            self.target == other.target
            and self.kind == other.kind
            and self.locations == other.locations
        )

    def __repr__(self) -> str:
        return (
            f"CallGraphTarget(target={self.target!r}, kind={self.kind!r}, "
            f"locations={self.locations!r})"
        )


class ClassHierarchy:
    """A class hierarchy mapping each class name to the superclass names that
    Pyre reported for it."""

    def __init__(self, hierarchy: Dict[str, List[str]]) -> None:
        self.hierarchy = hierarchy
        # Built lazily by `reverse_hierarchy`. It is kept on the instance (not in
        # a module-global lru_cache) so instances are not pinned by the cache and
        # several hierarchies do not evict each other's result.
        self._reverse_hierarchy: Optional[Dict[str, List[str]]] = None

    @property
    def reverse_hierarchy(self) -> Dict[str, List[str]]:
        """Map each class name to the classes that list it as a superclass.

        Every key of ``hierarchy`` is present, initialized to ``[]``, so a known
        class without subclasses can be distinguished from an unknown class.
        Superclasses that are not themselves keys of ``hierarchy`` are included
        as well. The mapping is computed once per instance.
        """
        if self._reverse_hierarchy is None:
            reversed_mapping: Dict[str, List[str]] = {
                key: [] for key in self.hierarchy
            }
            for key, values in self.hierarchy.items():
                for value in values:
                    # setdefault: a superclass need not appear as a key itself.
                    reversed_mapping.setdefault(value, []).append(key)
            self._reverse_hierarchy = reversed_mapping
        return self._reverse_hierarchy

    def subclasses(self, class_name: str) -> List[str]:
        """Return the classes listing ``class_name`` as a superclass ([] if unknown)."""
        return self.reverse_hierarchy.get(class_name, [])

    def superclasses(self, class_name: str) -> List[str]:
        """Return the superclasses recorded for ``class_name`` ([] if unknown)."""
        return self.hierarchy.get(class_name, [])


@dataclass
class PyreCache:
    """Mutable holder for query results reused across calls."""

    class_hierarchy: Optional[ClassHierarchy] = None


class InvalidModel(NamedTuple):
    """A taint-model validation error from ``validate_taint_models()``."""

    # Set to "" by `get_invalid_taint_models`.
    fully_qualified_name: str
    path: Optional[str]
    line: int
    column: int
    stop_line: int
    stop_column: int
    full_error_message: str


def _defines(pyre_connection: PyreConnection, modules: Iterable[str]) -> List[Define]:
    """Send one ``defines(...)`` query for ``modules`` and parse the result."""
    query = "defines({})".format(",".join(modules))
    result = pyre_connection.query_server(query)
    return [
        Define(
            name=element["name"],
            parameters=[
                DefineParameter(
                    name=parameter["name"], annotation=parameter["annotation"]
                )
                for parameter in element["parameters"]
            ],
            return_annotation=element["return_annotation"],
        )
        for element in result["response"]
    ]


def defines(
    pyre_connection: PyreConnection,
    modules: Iterable[str],
    batch_size: Optional[int] = None,
) -> List[Define]:
    """Return the defines of ``modules``.

    Args:
        pyre_connection: Connection used to send the queries.
        modules: Module names to query.
        batch_size: Maximum number of modules per query. ``None`` sends a
            single query containing every module.

    Raises:
        ValueError: if ``batch_size`` is given and is not positive.
    """
    found_defines: List[Define] = []
    for batch in _get_batch(modules, batch_size):
        found_defines.extend(_defines(pyre_connection, batch))
    return found_defines


def get_class_hierarchy(pyre_connection: PyreConnection) -> ClassHierarchy:
    """Query ``dump_class_hierarchy()`` and build a :class:`ClassHierarchy`."""
    result = pyre_connection.query_server("dump_class_hierarchy()")
    # The response is a list of dicts; merge them into one mapping (a key that
    # appears in several dicts keeps the last value seen).
    return ClassHierarchy(
        {
            key: edges
            for annotation_and_edges in result["response"]
            for key, edges in annotation_and_edges.items()
        }
    )


def get_cached_class_hierarchy(
    pyre_connection: PyreConnection, pyre_cache: Optional[PyreCache]
) -> ClassHierarchy:
    """Return the class hierarchy stored in ``pyre_cache``, querying it if needed.

    When ``pyre_cache`` is given, a freshly queried hierarchy is stored in it.
    """
    if pyre_cache is not None and pyre_cache.class_hierarchy is not None:
        return pyre_cache.class_hierarchy
    class_hierarchy = get_class_hierarchy(pyre_connection)
    if pyre_cache is not None:
        pyre_cache.class_hierarchy = class_hierarchy
    return class_hierarchy


def _annotations_per_file(data: PyreQueryResult) -> Dict[str, List[Annotation]]:
    """Group the annotations of a ``batch(types(...), ...)`` result by path.

    Sub-results without a ``"response"`` key are skipped.
    """
    return {
        response["response"][0]["path"]: [
            Annotation(
                locations_and_annotations["annotation"],
                _parse_position(locations_and_annotations["location"]["start"]),
                _parse_position(locations_and_annotations["location"]["stop"]),
            )
            for locations_and_annotations in response["response"][0]["types"]
        ]
        for response in data["response"]
        if "response" in response
    }


def get_types(
    pyre_connection: PyreConnection, *paths: str
) -> Dict[str, List[Annotation]]:
    """Return the inferred type annotations of each file in ``paths``."""
    # NOTE(review): paths are interpolated between single quotes without
    # escaping; a path containing `'` presumably yields a malformed query.
    types_sequence = ",".join([f"types('{path}')" for path in paths])
    result = pyre_connection.query_server(f"batch({types_sequence})")
    return _annotations_per_file(result)


def get_superclasses(pyre_connection: PyreConnection, class_name: str) -> List[str]:
    """Return the superclasses Pyre reports for ``class_name``."""
    query = f"superclasses({class_name})"
    result = pyre_connection.query_server(query)
    return result["response"][0][class_name]


def _get_batch(
    iterable: Iterable[T], batch_size: Optional[int]
) -> Generator[List[T], None, None]:
    """Yield the items of ``iterable`` as lists of at most ``batch_size`` items.

    With ``batch_size=None`` all items are yielded as one list (even if empty).
    Batches are always materialized lists, so callers may iterate each batch
    more than once.

    Raises:
        ValueError: if ``batch_size`` is not None and not positive (0 included).
            As this is a generator, the error surfaces on the first iteration.
    """
    if batch_size is None:
        yield list(iterable)
        return
    if batch_size <= 0:
        raise ValueError(
            "batch_size must be a positive integer, provided: `{}`".format(batch_size)
        )
    iterator = iter(iterable)
    batch = list(islice(iterator, batch_size))
    while batch:
        yield batch
        batch = list(islice(iterator, batch_size))


def _parse_attributes(response: Dict[str, Any]) -> List[Attributes]:
    """Convert the ``"attributes"`` list of an ``attributes(...)`` response."""
    return [
        Attributes(
            name=attribute["name"],
            annotation=attribute["annotation"],
            kind=attribute["kind"],
            final=attribute["final"],
        )
        for attribute in response["attributes"]
    ]


def _get_attributes(
    pyre_connection: PyreConnection, class_name: str
) -> List[Attributes]:
    """Query the attributes of a single class."""
    query = f"attributes({class_name})"
    response = pyre_connection.query_server(query)["response"]
    return _parse_attributes(response)


def get_attributes(
    pyre_connection: PyreConnection,
    class_names: Iterable[str],
    batch_size: Optional[int] = None,
) -> Dict[str, List[Attributes]]:
    """Return the attributes of each class in ``class_names``.

    Args:
        pyre_connection: Connection used to send the queries.
        class_names: Classes to query; any iterable (including a generator).
        batch_size: Maximum number of classes per ``batch(...)`` query. ``None``
            sends a single query for all classes.

    Raises:
        ValueError: if ``batch_size`` is given and is not positive.
    """
    all_responses: Dict[str, List[Attributes]] = {}
    for batch in _get_batch(class_names, batch_size):
        query = "batch({})".format(", ".join(f"attributes({name})" for name in batch))
        responses = pyre_connection.query_server(query)["response"]
        # zip pairs names with sub-responses positionally, which assumes Pyre
        # answers the sub-queries in request order.
        all_responses.update(
            {
                class_name: _parse_attributes(response["response"])
                for class_name, response in zip(batch, responses)
            }
        )
    return all_responses


def get_call_graph(
    pyre_connection: PyreConnection,
) -> Dict[str, List[CallGraphTarget]]:
    """Query ``dump_call_graph()`` and map each caller to its parsed callees."""
    response = pyre_connection.query_server("dump_call_graph()")["response"]
    return {
        function: [CallGraphTarget(call) for call in calls]
        for function, calls in response.items()
    }


def _parse_location(location_json: Dict[str, Any]) -> Location:
    """Build a :class:`Location` from its JSON form."""
    return Location(
        path=location_json["path"],
        start=_parse_position(location_json["start"]),
        stop=_parse_position(location_json["stop"]),
    )


def _parse_position(position_json: Dict[str, Any]) -> Position:
    """Build a :class:`Position` from a ``{"line": ..., "column": ...}`` dict."""
    return Position(line=position_json["line"], column=position_json["column"])


def get_invalid_taint_models(
    pyre_connection: PyreConnection,
) -> List[InvalidModel]:
    """Run ``validate_taint_models()`` and convert each reported error.

    Returns an empty list when the result has no ``response``/``errors`` entry.
    ``fully_qualified_name`` is always set to ``""``.
    """
    response = pyre_connection.query_server("validate_taint_models()")
    if "response" not in response or "errors" not in response["response"]:
        return []
    return [
        InvalidModel(
            full_error_message=error["description"],
            path=error["path"],
            line=error["line"],
            column=error["column"],
            stop_line=error["stop_line"],
            stop_column=error["stop_column"],
            fully_qualified_name="",
        )
        for error in response["response"]["errors"]
    ]