scripts/explore_pysa_models.py (266 lines of code) (raw):

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Helpers to explore the models and issues stored in a Pysa taint output file.

Call `index()` once to build an in-memory index of the file, then use the
query helpers (`get_model`, `print_model`, `get_issues`, ...). Running this
script directly prints the list of available commands.
"""

import collections
import io
import json
import re
import subprocess
import sys
import textwrap
from pathlib import Path
from typing import Optional, Any, Dict, List, Tuple, Iterable, NamedTuple, Callable


class FilePosition(NamedTuple):
    """Byte range of a single line of the taint output file."""

    offset: int  # Byte offset at which the line starts.
    length: int  # Length of the line in bytes, as read (newline included if present).


# Open binary handle on the indexed file; entries are read from it lazily.
__handle: Optional[io.BufferedReader] = None
# Callable name -> position of the line holding its model.
__model_index: Dict[str, FilePosition] = {}
# Callable name -> positions of the lines holding its issues.
__issue_index: Dict[str, List[FilePosition]] = collections.defaultdict(list)
# Set once the "install jq" hint has been printed, so it is shown only once.
__warned_missing_jq: bool = False


def _iter_with_offset(lines: Iterable[bytes]) -> Iterable[Tuple[bytes, int]]:
    """Yield `(line, offset)` pairs, where `offset` is the line's first byte."""
    offset = 0
    for line in lines:
        yield (line, offset)
        offset += len(line)


def _resolve_taint_output_path(taint_output_filename: str) -> Path:
    """Map a directory to the `taint-output.json` file inside it."""
    taint_output_path = Path(taint_output_filename)
    if taint_output_path.is_dir():
        taint_output_path = taint_output_path / "taint-output.json"
    return taint_output_path


def _index_lines(
    handle: io.BufferedReader,
) -> Tuple[Dict[str, FilePosition], Dict[str, List[FilePosition]]]:
    """Scan every line of `handle` and record where each model and issue is.

    Returns:
        A `(model_index, issue_index)` pair mapping callable names to positions.

    Raises:
        AssertionError: if the file holds two models for the same callable.
    """
    model_index: Dict[str, FilePosition] = {}
    issue_index: Dict[str, List[FilePosition]] = collections.defaultdict(list)
    for line, offset in _iter_with_offset(handle):
        if not line.strip():
            # Blank lines (e.g. a trailing empty line) carry no message.
            continue
        message = json.loads(line)
        if "kind" not in message:
            # Lines without a `kind` (presumably file metadata) are not indexed.
            continue
        position = FilePosition(offset=offset, length=len(line))
        if message["kind"] == "model":
            callable_name = message["data"]["callable"]
            if callable_name in model_index:
                # Explicit raise rather than `assert` so the check survives `-O`.
                raise AssertionError(
                    f"duplicate model for callable `{callable_name}`."
                )
            model_index[callable_name] = position
        elif message["kind"] == "issue":
            issue_index[message["data"]["callable"]].append(position)
    return model_index, issue_index


def index(taint_output_filename: str = "taint-output.json") -> None:
    """Index all available models in the given taint output file or directory."""
    global __handle, __model_index, __issue_index

    print(f"Indexing `{taint_output_filename}`")
    taint_output_path = _resolve_taint_output_path(taint_output_filename)

    handle = open(taint_output_path, "rb")
    try:
        model_index, issue_index = _index_lines(handle)
    except BaseException:
        # Keep the previous index (if any) intact and don't leak the new handle.
        handle.close()
        raise

    # Publish the complete index, then release the previously indexed file.
    previous_handle = __handle
    __handle = handle
    __model_index = model_index
    __issue_index = issue_index
    if previous_handle is not None:
        previous_handle.close()

    count_models = len(model_index)
    count_issues = sum(len(positions) for positions in issue_index.values())
    print(f"Indexed {count_models} models and {count_issues} issues")


def _assert_loaded() -> io.BufferedReader:
    """Return the handle of the indexed file.

    Raises:
        AssertionError: if `index()` has not completed successfully yet.
    """
    handle = __handle
    # Only the handle matters: a file may legitimately contain zero models.
    if handle is None:
        raise AssertionError("call index() first")
    return handle


def callables_containing(string: str) -> List[str]:
    """Find all callables containing the given string."""
    _assert_loaded()
    return sorted(name for name in __model_index if string in name)


def callables_matching(pattern: str) -> List[str]:
    """Find all callables matching the given regular expression."""
    _assert_loaded()
    regex = re.compile(pattern)
    return sorted(name for name in __model_index if regex.search(name))


def _read(position: FilePosition) -> bytes:
    """Read the raw bytes of the line at `position` from the indexed file."""
    handle = _assert_loaded()
    handle.seek(position.offset)
    return handle.read(position.length)


def _filter_taint_tree(
    taint_tree: List[Dict[str, Any]],
    frame_predicate: Callable[[str, Dict[str, Any]], bool],
) -> List[Dict[str, Any]]:
    """Return a copy of `taint_tree` keeping only frames accepted by the predicate.

    `frame_predicate` receives the caller port and the frame. Local taints left
    without frames, and ports left without local taints, are dropped. Kept
    entries are shallow copies; the frame dicts themselves are shared.
    """
    new_taint_tree = []
    for taint in taint_tree:
        caller_port = taint["port"]
        new_local_taints = []
        for local_taint in taint["taint"]:
            new_kinds = [
                frame
                for frame in local_taint["kinds"]
                if frame_predicate(caller_port, frame)
            ]
            if new_kinds:
                new_local_taint = local_taint.copy()
                new_local_taint["kinds"] = new_kinds
                new_local_taints.append(new_local_taint)
        if new_local_taints:
            new_taint = taint.copy()
            new_taint["taint"] = new_local_taints
            new_taint_tree.append(new_taint)
    return new_taint_tree


def filter_model(
    model: Dict[str, Any], frame_predicate: Callable[[str, Dict[str, Any]], bool]
) -> Dict[str, Any]:
    """Return a shallow copy of `model` keeping only frames accepted by the predicate.

    The `sources`, `sinks` and `tito` trees are filtered; a missing tree is
    treated as empty and therefore appears as `[]` in the result.
    """
    model = model.copy()
    model["sources"] = _filter_taint_tree(model.get("sources", []), frame_predicate)
    model["sinks"] = _filter_taint_tree(model.get("sinks", []), frame_predicate)
    model["tito"] = _filter_taint_tree(model.get("tito", []), frame_predicate)
    return model


def filter_model_caller_port(model: Dict[str, Any], port: str) -> Dict[str, Any]:
    """Keep only the taint attached to the given caller port."""

    def predicate(caller_port: str, frame: Dict[str, Any]) -> bool:
        return port == caller_port

    return filter_model(model, predicate)


def filter_model_kind(model: Dict[str, Any], kind: str) -> Dict[str, Any]:
    """Keep only the frames whose `kind` equals the given taint kind."""

    def predicate(caller_port: str, frame: Dict[str, Any]) -> bool:
        return frame["kind"] == kind

    return filter_model(model, predicate)


def _map_taint_tree(
    taint_tree: List[Dict[str, Any]],
    frame_map: Callable[[str, Dict[str, Any]], None],
    local_taint_map: Callable[[str, Dict[str, Any]], None],
) -> List[Dict[str, Any]]:
    """Return a copy of `taint_tree` with the callbacks applied to copies.

    `frame_map` mutates a copy of each frame and `local_taint_map` mutates a
    copy of each local taint (after its frames were mapped); both receive the
    caller port first. The input tree is never mutated.
    """
    new_taint_tree = []
    for taint in taint_tree:
        caller_port = taint["port"]
        new_local_taints = []
        for local_taint in taint["taint"]:
            new_kinds = []
            for frame in local_taint["kinds"]:
                new_frame = frame.copy()
                frame_map(caller_port, new_frame)
                new_kinds.append(new_frame)
            new_local_taint = local_taint.copy()
            new_local_taint["kinds"] = new_kinds
            local_taint_map(caller_port, new_local_taint)
            new_local_taints.append(new_local_taint)
        new_taint = taint.copy()
        new_taint["taint"] = new_local_taints
        new_taint_tree.append(new_taint)
    return new_taint_tree


def map_model(
    model: Dict[str, Any],
    frame_map: Optional[Callable[[str, Dict[str, Any]], None]] = None,
    local_taint_map: Optional[Callable[[str, Dict[str, Any]], None]] = None,
) -> Dict[str, Any]:
    """Return a copy of `model` with the in-place callbacks applied to copies.

    Omitted callbacks default to no-ops. Like `filter_model`, missing
    `sources`/`sinks`/`tito` trees appear as `[]` in the result.
    """
    frame_map = frame_map if frame_map is not None else lambda x, y: None
    local_taint_map = (
        local_taint_map if local_taint_map is not None else lambda x, y: None
    )
    model = model.copy()
    model["sources"] = _map_taint_tree(
        model.get("sources", []), frame_map, local_taint_map
    )
    model["sinks"] = _map_taint_tree(model.get("sinks", []), frame_map, local_taint_map)
    model["tito"] = _map_taint_tree(model.get("tito", []), frame_map, local_taint_map)
    return model


def model_remove_tito_positions(model: Dict[str, Any]) -> Dict[str, Any]:
    """Drop the `tito` entry of every local taint."""

    def local_taint_map(caller_port: str, local_taint: Dict[str, Any]) -> None:
        if "tito" in local_taint:
            del local_taint["tito"]

    return map_model(model, local_taint_map=local_taint_map)


def model_remove_features(model: Dict[str, Any]) -> Dict[str, Any]:
    """Drop frame `features` and local-taint `local_features`."""

    def frame_map(caller_port: str, frame: Dict[str, Any]) -> None:
        if "features" in frame:
            del frame["features"]

    def local_taint_map(caller_port: str, local_taint: Dict[str, Any]) -> None:
        if "local_features" in local_taint:
            del local_taint["local_features"]

    return map_model(model, frame_map=frame_map, local_taint_map=local_taint_map)


def model_remove_leaf_names(model: Dict[str, Any]) -> Dict[str, Any]:
    """Drop the `leaves` entry of every frame."""

    def frame_map(caller_port: str, frame: Dict[str, Any]) -> None:
        if "leaves" in frame:
            del frame["leaves"]

    return map_model(model, frame_map=frame_map)


def get_model(
    callable: str,
    *,
    kind: Optional[str] = None,
    caller_port: Optional[str] = None,
    remove_sources: bool = False,
    remove_sinks: bool = False,
    remove_tito: bool = False,
    remove_tito_positions: bool = False,
    remove_features: bool = False,
    remove_leaf_names: bool = False,
) -> Dict[str, Any]:
    """Get the model for the given callable.

    See `print_model` for the optional filtering parameters.
    Raises AssertionError if there is no model for the callable.
    """
    _assert_loaded()
    if callable not in __model_index:
        raise AssertionError(f"no model for callable `{callable}`.")

    message = json.loads(_read(__model_index[callable]))
    assert message["kind"] == "model"
    model = message["data"]

    if kind is not None:
        model = filter_model_kind(model, kind)
    if caller_port is not None:
        model = filter_model_caller_port(model, caller_port)
    if remove_tito_positions:
        model = model_remove_tito_positions(model)
    if remove_features:
        model = model_remove_features(model)
    if remove_leaf_names:
        model = model_remove_leaf_names(model)

    # Remove whole trees last: the filter/map helpers above always (re)create
    # `sources`, `sinks` and `tito`, which would resurrect a removed tree.
    if remove_sources:
        model.pop("sources", None)
    if remove_sinks:
        model.pop("sinks", None)
    if remove_tito:
        model.pop("tito", None)
    return model


def print_json(data: object) -> None:
    """Pretty print json objects with syntax highlighting."""
    global __warned_missing_jq

    if isinstance(data, str):
        data = json.loads(data)
    # `jq` writes directly to the process' stdout; flush Python's buffer first
    # so earlier `print` output is not emitted after the JSON.
    sys.stdout.flush()
    try:
        subprocess.run(["jq", "-C"], input=json.dumps(data).encode(), check=True)
    except FileNotFoundError:
        # `jq` is not installed: fall back to plain, uncolored output.
        print(json.dumps(data, indent=" "))
        if not __warned_missing_jq:
            print(
                "[HINT] Install `jq` to use syntax highlighting, https://stedolan.github.io/jq/"
            )
            __warned_missing_jq = True


def print_model(
    callable: str,
    *,
    kind: Optional[str] = None,
    caller_port: Optional[str] = None,
    remove_sources: bool = False,
    remove_sinks: bool = False,
    remove_tito: bool = False,
    remove_tito_positions: bool = False,
    remove_features: bool = False,
    remove_leaf_names: bool = False,
) -> None:
    """
    Pretty print the model for the given callable.
    Optional parameters (all removals default to False):
      kind='UserControlled'       Filter by taint kind.
      caller_port='result'        Filter by caller port.
      remove_sources=True         Drop the sources.
      remove_sinks=True           Drop the sinks.
      remove_tito=True            Drop the tito tree.
      remove_tito_positions=True  Drop tito positions of local taints.
      remove_features=True        Drop features and local features.
      remove_leaf_names=True      Drop leaf names.
    """
    print_json(
        get_model(
            callable,
            kind=kind,
            caller_port=caller_port,
            remove_sources=remove_sources,
            remove_sinks=remove_sinks,
            remove_tito=remove_tito,
            remove_tito_positions=remove_tito_positions,
            remove_features=remove_features,
            remove_leaf_names=remove_leaf_names,
        )
    )


def get_issues(callable: str) -> List[Dict[str, Any]]:
    """Get all issues within the given callable."""
    _assert_loaded()
    issues = []
    # `.get` instead of `[]`: don't insert empty entries into the defaultdict.
    for position in __issue_index.get(callable, []):
        message = json.loads(_read(position))
        assert message["kind"] == "issue"
        issues.append(message["data"])
    return issues


def print_issues(callable: str) -> None:
    """Pretty print the issues within the given callable."""
    print_json(get_issues(callable))


def print_help() -> None:
    """Print the list of available commands with their docstrings."""
    print("# Pysa Model Explorer")
    print("Available commands:")
    commands = [
        (index, "index('taint-output.json')"),
        (callables_containing, "callables_containing('foo.bar')"),
        (callables_matching, "callables_matching(r'foo\\..*')"),
        (get_model, "get_model('foo.bar')"),
        (print_model, "print_model('foo.bar')"),
        (get_issues, "get_issues('foo.bar')"),
        (print_issues, "print_issues('foo.bar')"),
        (print_json, "print_json({'a': 'b'})"),
    ]
    max_width = max(len(example) for _, example in commands)
    for command, example in commands:
        doc = textwrap.dedent(command.__doc__ or "")
        # Indent continuation lines so they align under the first doc line.
        doc = textwrap.indent(doc, prefix=" " * (max_width + 3)).strip()
        print(f" {example:<{max_width}} {doc}")


if __name__ == "__main__":
    print_help()