tools/generate_taint_models/get_graphql_sources.py (68 lines of code) (raw):

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import os
from importlib import import_module
from typing import Any, Callable, Iterable, List, Type, Union

from .generator_specifications import AllParametersAnnotation
from .model import CallableModel
from .model_generator import ModelGenerator

# pyre-ignore: Too dynamic.
GraphQLObjectType = Type[Any]


class GraphQLSourceGenerator(ModelGenerator[CallableModel]):
    """Generate taint models for the resolvers of GraphQL object types.

    Every direct ``.py`` submodule (except ``__init__.py``) of the configured
    package(s) is imported. Each module-level object that is an instance of
    ``graphql_object_type`` has the ``resolve`` callables of its fields
    collected. Each collected resolver is modeled with
    ``args_taint_annotation`` on its ``*args``/``**kwargs`` and
    ``return_taint_annotation`` on its return value.
    """

    def __init__(
        self,
        graphql_module: Union[List[str], str],
        graphql_object_type: GraphQLObjectType,
        args_taint_annotation: str = "TaintSource[UserControlled]",
        return_taint_annotation: str = "TaintSink[ReturnedToUser]",
    ) -> None:
        """
        Args:
            graphql_module: Dotted name of a package, or a list of them, whose
                direct ``.py`` submodules are scanned for GraphQL objects.
            graphql_object_type: Class whose module-level instances are treated
                as GraphQL objects. Assumed to expose a ``fields`` mapping whose
                values have a ``resolve`` attribute (TODO confirm against the
                GraphQL library in use).
            args_taint_annotation: Annotation applied to each resolver's
                ``*args`` and ``**kwargs``.
            return_taint_annotation: Annotation applied to each resolver's
                return value.
        """
        super().__init__()
        self.graphql_module: Union[List[str], str] = graphql_module
        self.graphql_object_type: GraphQLObjectType = graphql_object_type
        self.args_taint_annotation: str = args_taint_annotation
        self.return_taint_annotation: str = return_taint_annotation

    def gather_functions_to_model(self) -> Iterable[Callable[..., object]]:
        """Import every candidate submodule and return the resolvers found.

        The list may contain the same resolver more than once;
        ``compute_models`` deduplicates through a set.
        """
        views: List[Callable[..., object]] = []
        # Collect all names first so every package is imported before any of
        # its submodules, as in the original two-phase flow.
        for module_name in self._graphql_submodule_names():
            views.extend(self._resolvers_in_module(module_name))
        return views

    def _graphql_submodule_names(self) -> List[str]:
        """Return dotted names of the direct ``.py`` submodules of each package.

        Raises:
            ValueError: if a configured package has no ``__file__`` (for
                example a namespace package), so its directory cannot be found.
        """
        module_argument = self.graphql_module
        graphql_modules = (
            [module_argument] if isinstance(module_argument, str) else module_argument
        )
        modules: List[str] = []
        for graphql_module in graphql_modules:
            module_file = import_module(graphql_module).__file__
            if module_file is None:
                # Without this check os.path.dirname(None) raises an opaque
                # TypeError.
                raise ValueError(
                    f"Cannot locate the source directory of `{graphql_module}`; "
                    "namespace packages are not supported."
                )
            # os.listdir order is arbitrary; sort so submodules are imported in
            # a reproducible order across runs.
            for path in sorted(os.listdir(os.path.dirname(module_file))):
                if path.endswith(".py") and path != "__init__.py":
                    modules.append(f"{graphql_module}.{path[:-3]}")
        return modules

    def _resolvers_in_module(self, module_name: str) -> List[Callable[..., object]]:
        """Import ``module_name`` and return its GraphQL objects' named resolvers."""
        module = import_module(module_name)
        resolvers: List[Callable[..., object]] = []
        for element in module.__dict__.values():
            if not isinstance(element, self.graphql_object_type):
                continue
            try:
                fields = element.fields
            except AssertionError:
                # GraphQL throws an exception when a GraphQL object is created
                # with 0 fields. Since we don't control the library, we need to
                # program defensively here :(
                continue
            for field in fields.values():
                resolver = field.resolve
                # Skip missing resolvers, lambdas, and callables with no
                # __name__ (e.g. functools.partial objects, callable instances).
                # Previously the last kind raised AttributeError and aborted
                # the whole scan. NOTE(review): presumably such callables cannot
                # be named in a model anyway — confirm with CallableModel.
                name = getattr(resolver, "__name__", None)
                if resolver is not None and name is not None and name != "<lambda>":
                    resolvers.append(resolver)
        return resolvers

    def compute_models(
        self, functions_to_model: Iterable[Callable[..., object]]
    ) -> Iterable[CallableModel]:
        """Build one model per distinct resolver and return them sorted.

        ``*args`` and ``**kwargs`` are annotated with ``args_taint_annotation``
        and the return value with ``return_taint_annotation``.
        """
        graphql_models = set()
        for view_function in functions_to_model:
            try:
                model = CallableModel(
                    callable_object=view_function,
                    parameter_annotation=AllParametersAnnotation(
                        vararg=self.args_taint_annotation,
                        kwarg=self.args_taint_annotation,
                    ),
                    returns=self.return_taint_annotation,
                )
            except ValueError:
                # Best effort: skip functions CallableModel rejects.
                # NOTE(review): presumably ValueError signals a callable it
                # cannot model — confirm in model.py.
                continue
            graphql_models.add(model)
        return sorted(graphql_models)