tools/generate_taint_models/model.py (300 lines of code) (raw):
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import _ast
import abc
import ast
import logging
from typing import Callable, Iterable, List, Optional, Set, Union
from ...api import query
from .generator_specifications import (
AnnotationSpecification,
ParameterAnnotation,
WhitelistSpecification,
)
from .inspect_parser import extract_parameters, extract_qualified_name
from .parameter import Parameter
# Alias covering both synchronous and asynchronous function definition nodes.
FunctionDefinition = Union[_ast.FunctionDef, _ast.AsyncFunctionDef]
# Module-level logger, named after this module.
LOG: logging.Logger = logging.getLogger(__name__)
class Model(abc.ABC):
    """Abstract base class for generated models.

    Models are ordered by comparing their `str()` renderings. Concrete
    subclasses must implement both `__eq__` and `__hash__`.
    """

    def __lt__(self, other: "Model") -> bool:
        """Return True when this model's string form sorts before `other`'s."""
        return str(self) < str(other)

    @abc.abstractmethod
    def __eq__(self, other: object) -> bool:
        """Return whether `other` denotes the same model as this one."""
        ...

    @abc.abstractmethod
    def __hash__(self) -> int:
        """Return a hash consistent with `__eq__`."""
        ...
class RawCallableModel(Model):
    """Model of a callable, rendered as `def <name>(<params>) -> <returns>: ...`.

    Subclasses supply the qualified name and the parameter list through the
    two abstract hooks; this class handles annotation, whitelisting,
    rendering, equality and hashing.
    """

    callable_name: str
    parameters: List[Parameter]
    annotations: AnnotationSpecification
    whitelist: WhitelistSpecification
    returns: Optional[str] = None

    def __init__(
        self,
        parameter_annotation: Optional[ParameterAnnotation] = None,
        returns: Optional[str] = None,
        parameter_type_whitelist: Optional[Iterable[str]] = None,
        parameter_name_whitelist: Optional[Set[str]] = None,
        annotations: Optional[AnnotationSpecification] = None,
        whitelist: Optional[WhitelistSpecification] = None,
    ) -> None:
        """Build the model.

        `annotations` and `whitelist`, when given (and truthy), take
        precedence over the individual `parameter_annotation` / `returns` /
        `parameter_*_whitelist` arguments.

        Raises:
            ValueError: if the subclass yields no callable name, or the name
                contains a "-".
        """
        self.annotations = annotations or AnnotationSpecification(
            parameter_annotation=parameter_annotation, returns=returns
        )

        if not whitelist:
            type_whitelist = (
                set(parameter_type_whitelist) if parameter_type_whitelist else None
            )
            whitelist = WhitelistSpecification(
                parameter_type=type_whitelist,
                parameter_name=parameter_name_whitelist,
            )
        self.whitelist = whitelist

        qualified_name = self._get_fully_qualified_callable_name()
        # Object construction should fail if any child class passes in a None.
        if not qualified_name or "-" in qualified_name:
            raise ValueError("The callable is not supported")

        self.callable_name = qualified_name
        self.parameters = self._generate_parameters()

    @abc.abstractmethod
    def _generate_parameters(self) -> List["Parameter"]:
        """Return the parameters of the modelled callable."""
        ...

    @abc.abstractmethod
    def _get_fully_qualified_callable_name(self) -> Optional[str]:
        """Return the qualified name of the modelled callable, if any."""
        ...

    def __str__(self) -> str:
        """Render the model as a stub definition line."""
        skipped_names = self.whitelist.parameter_name
        skipped_types = self.whitelist.parameter_type

        rendered: List[str] = []
        for parameter in self.parameters:
            # A parameter matching either whitelist is never annotated.
            whitelisted = (
                skipped_names is not None and parameter.name in skipped_names
            ) or (
                skipped_types is not None and parameter.annotation in skipped_types
            )
            annotator = None if whitelisted else self.annotations.parameter_annotation
            taint = annotator.get(parameter) if annotator is not None else None

            # * parameters indicate kwargs after the parameter position, and can't be
            # tainted. Example: `def foo(x, *, y): ...`
            if parameter.name != "*" and taint:
                rendered.append(f"{parameter.name}: {taint}")
            else:
                rendered.append(parameter.name)

        returns = self.annotations.returns
        suffix = f" -> {returns}" if returns else ""
        return f"def {self.callable_name}({', '.join(rendered)}){suffix}: ..."

    def __eq__(self, other: object) -> bool:
        """Two callable models are equal when name and parameters match."""
        return (
            isinstance(other, RawCallableModel)
            and self.callable_name == other.callable_name
            and self.parameters == other.parameters
        )

    # Need to explicitly define this(despite baseclass) as we are overriding eq
    def __hash__(self) -> int:
        """Hash the callable name together with each parameter's name/annotation."""
        fragments = [
            f"{parameter.name}:{parameter.annotation}"
            if parameter.annotation
            else f"{parameter.name}:_empty"
            for parameter in self.parameters
        ]
        return hash((self.callable_name, ",".join(fragments)))
class CallableModel(RawCallableModel):
    """Callable model whose name and parameters come from a callable object."""

    callable_object: Callable[..., object]

    def __init__(
        self,
        callable_object: Callable[..., object],
        parameter_annotation: Optional[ParameterAnnotation] = None,
        returns: Optional[str] = None,
        parameter_type_whitelist: Optional[Iterable[str]] = None,
        parameter_name_whitelist: Optional[Set[str]] = None,
        annotations: Optional[AnnotationSpecification] = None,
        whitelist: Optional[WhitelistSpecification] = None,
    ) -> None:
        """Store `callable_object` and forward all other options to the base class."""
        # Stored first: the base constructor invokes the two hooks below,
        # and both of them read this attribute.
        self.callable_object = callable_object
        super().__init__(
            annotations=annotations,
            whitelist=whitelist,
            parameter_annotation=parameter_annotation,
            returns=returns,
            parameter_type_whitelist=parameter_type_whitelist,
            parameter_name_whitelist=parameter_name_whitelist,
        )

    def _generate_parameters(self) -> List[Parameter]:
        """Delegate parameter extraction to `extract_parameters`."""
        target = self.callable_object
        return extract_parameters(target)

    def _get_fully_qualified_callable_name(self) -> Optional[str]:
        """Delegate name extraction to `extract_qualified_name`."""
        target = self.callable_object
        return extract_qualified_name(target)
class FunctionDefinitionModel(RawCallableModel):
    """Callable model built from an `ast` function definition node.

    The callable name is the definition's name, prefixed with
    `"<qualifier>."` when a qualifier is supplied.
    """

    definition: FunctionDefinition
    qualifier: Optional[str] = None

    def __init__(
        self,
        definition: FunctionDefinition,
        qualifier: Optional[str] = None,
        parameter_annotation: Optional[ParameterAnnotation] = None,
        returns: Optional[str] = None,
        parameter_type_whitelist: Optional[Iterable[str]] = None,
        parameter_name_whitelist: Optional[Set[str]] = None,
        annotations: Optional[AnnotationSpecification] = None,
        whitelist: Optional[WhitelistSpecification] = None,
    ) -> None:
        """Store the definition and qualifier, then run the base constructor."""
        # Both attributes must exist before the base constructor runs, since
        # it calls the name and parameter hooks defined below.
        self.definition = definition
        self.qualifier = qualifier
        super().__init__(
            parameter_annotation=parameter_annotation,
            returns=returns,
            parameter_type_whitelist=parameter_type_whitelist,
            parameter_name_whitelist=parameter_name_whitelist,
            annotations=annotations,
            whitelist=whitelist,
        )

    @staticmethod
    def _get_annotation(ast_arg: ast.arg) -> Optional[str]:
        """Return the annotation's identifier when it is a plain name, else None."""
        annotation = ast_arg.annotation
        if annotation and isinstance(annotation, _ast.Name):
            return annotation.id
        else:
            return None

    def _generate_parameters(self) -> List[Parameter]:
        """Return the parameters in the order they appear in a signature.

        Order: regular arguments, then `*args` (or a bare `*` marker when
        there are keyword-only arguments but no `*args`), then keyword-only
        arguments, then `**kwargs`.
        """
        get_annotation = FunctionDefinitionModel._get_annotation
        function_arguments = self.definition.args
        # NOTE(review): `function_arguments.posonlyargs` is not consulted, so
        # positional-only parameters are omitted - confirm whether intended.

        parameters: List[Parameter] = [
            Parameter(ast_arg.arg, get_annotation(ast_arg), Parameter.Kind.ARG)
            for ast_arg in function_arguments.args
        ]

        vararg_parameters = function_arguments.vararg
        keyword_only_parameters = function_arguments.kwonlyargs

        if isinstance(vararg_parameters, ast.arg):
            # `*args` precedes keyword-only parameters and already separates
            # them from the positional ones, so no bare `*` is emitted.
            parameters.append(
                Parameter(
                    f"*{vararg_parameters.arg}",
                    get_annotation(vararg_parameters),
                    Parameter.Kind.VARARG,
                )
            )
        elif keyword_only_parameters:
            # Without `*args`, a bare `*` marks the start of keyword-only
            # parameters.
            parameters.append(
                Parameter(name="*", annotation=None, kind=Parameter.Kind.ARG)
            )

        parameters.extend(
            Parameter(parameter.arg, get_annotation(parameter), Parameter.Kind.ARG)
            for parameter in keyword_only_parameters
        )

        kwarg_parameters = function_arguments.kwarg
        if isinstance(kwarg_parameters, ast.arg):
            parameters.append(
                Parameter(
                    f"**{kwarg_parameters.arg}",
                    get_annotation(kwarg_parameters),
                    Parameter.Kind.KWARG,
                )
            )

        return parameters

    def _get_fully_qualified_callable_name(self) -> Optional[str]:
        """Return the function name, prefixed by `qualifier.` when one is set."""
        qualifier = f"{self.qualifier}." if self.qualifier else ""
        fn_name = self.definition.name
        return qualifier + fn_name
class PyreFunctionDefinitionModel(RawCallableModel):
    """Callable model built from a `query.Define` value."""

    definition: query.Define

    def __init__(
        self,
        definition: query.Define,
        parameter_annotation: Optional[ParameterAnnotation] = None,
        returns: Optional[str] = None,
        parameter_type_whitelist: Optional[Iterable[str]] = None,
        parameter_name_whitelist: Optional[Set[str]] = None,
        annotations: Optional[AnnotationSpecification] = None,
        whitelist: Optional[WhitelistSpecification] = None,
    ) -> None:
        """Store `definition` and forward the remaining options to the base class."""
        # Stored up front because the base constructor calls the hooks below.
        self.definition = definition
        super().__init__(
            annotations=annotations,
            whitelist=whitelist,
            parameter_annotation=parameter_annotation,
            returns=returns,
            parameter_type_whitelist=parameter_type_whitelist,
            parameter_name_whitelist=parameter_name_whitelist,
        )

    def _generate_parameters(self) -> List[Parameter]:
        """Convert each parameter of the definition into a `Parameter`.

        The kind is derived from the name: a name containing "**" is a
        KWARG, one containing "*" is a VARARG, anything else is an ARG.
        """
        return [
            Parameter(
                name=entry.name,
                annotation=entry.annotation,
                kind=(
                    Parameter.Kind.KWARG
                    if "**" in entry.name
                    else Parameter.Kind.VARARG
                    if "*" in entry.name
                    else Parameter.Kind.ARG
                ),
            )
            for entry in self.definition.parameters
        ]

    def _get_fully_qualified_callable_name(self) -> Optional[str]:
        """Return the name carried by the definition."""
        define = self.definition
        return define.name
class AssignmentModel(Model):
    """Model rendered as `<target>: <annotation> = ...`.

    Identity (equality and hash) is determined by `target` alone.
    """

    annotation: str
    target: str

    def __init__(self, annotation: str, target: str) -> None:
        """Create the model.

        Raises:
            ValueError: if `target` contains a "-".
        """
        unsupported = "-" in target
        if unsupported:
            raise ValueError("The target is not supported")
        self.target = target
        self.annotation = annotation

    def __str__(self) -> str:
        """Render the annotated assignment stub."""
        target, annotation = self.target, self.annotation
        return f"{target}: {annotation} = ..."

    def __eq__(self, other: object) -> bool:
        """Assignment models are equal when their targets are equal."""
        return isinstance(other, AssignmentModel) and self.target == other.target

    def __hash__(self) -> int:
        """Hash on the target only, mirroring `__eq__`."""
        return hash(self.target)
class ClassModel(Model):
    """Model rendered as `class <class_name>(<annotation>): ...`.

    Identity (equality and hash) is determined by `class_name` alone.
    """

    class_name: str
    annotation: str

    def __init__(self, class_name: str, annotation: str) -> None:
        """Record the class name and the annotation placed in the base list."""
        self.annotation = annotation
        self.class_name = class_name

    def __str__(self) -> str:
        """Render the class stub."""
        name, annotation = self.class_name, self.annotation
        return f"class {name}({annotation}): ..."

    def __eq__(self, other: object) -> bool:
        """Class models are equal when their class names are equal."""
        return isinstance(other, ClassModel) and self.class_name == other.class_name

    def __hash__(self) -> int:
        """Hash on the class name only, mirroring `__eq__`."""
        return hash(self.class_name)
class PropertyModel(Model):
    """Model rendered as a `@property` stub returning `annotation`.

    Identity (equality and hash) is the pair `(class_name, attribute_name)`.
    """

    def __init__(self, class_name: str, attribute_name: str, annotation: str) -> None:
        """Record the owning class, the attribute name and the return annotation."""
        self.annotation = annotation
        self.attribute_name = attribute_name
        self.class_name = class_name

    def __str__(self) -> str:
        """Render the two-line property stub."""
        return (
            f"@property\ndef {self.class_name}.{self.attribute_name}(self)"
            f" -> {self.annotation}: ..."
        )

    def __eq__(self, other: object) -> bool:
        """Property models are equal when class and attribute names both match."""
        if isinstance(other, PropertyModel):
            return (
                self.class_name == other.class_name
                and self.attribute_name == other.attribute_name
            )
        return False

    def __hash__(self) -> int:
        """Hash on the same pair of fields that `__eq__` compares."""
        identity = (self.class_name, self.attribute_name)
        return hash(identity)