scripts/explore_pysa_models.py (266 lines of code) (raw):
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import collections
import io
import json
import re
import subprocess
import textwrap
from pathlib import Path
from typing import Optional, Any, Dict, List, Tuple, Iterable, NamedTuple, Callable
# Byte span of one JSON line inside the taint output file, used to seek back
# to a single model/issue record without re-reading the whole file.
FilePosition = NamedTuple(
    "FilePosition",
    [
        ("offset", int),  # byte offset at which the line starts
        ("length", int),  # number of bytes in the line (as yielded by the file)
    ],
)
# Binary handle on the indexed taint output file; None until index() has run.
# It stays open so that later lookups can seek to individual lines.
__handle: Optional[io.BufferedReader] = None
# Callable name -> byte span of the "model" line for that callable.
__model_index: Dict[str, FilePosition] = {}
# Callable name -> byte spans of every "issue" line recorded for that callable.
__issue_index: Dict[str, List[FilePosition]] = collections.defaultdict(list)
# Set after print_json() has shown the "install jq" hint, so it appears once.
__warned_missing_jq: bool = False
def _iter_with_offset(lines: Iterable[bytes]) -> Iterable[Tuple[bytes, int]]:
    """Yield each line together with the byte offset at which it starts.

    The offset is the running total of the lengths of all preceding lines, so
    when `lines` is a file opened in binary mode it equals the file position.
    """
    start = 0
    for chunk in lines:
        yield chunk, start
        start += len(chunk)
def _resolve_taint_output_path(taint_output_filename: str) -> Path:
    """Return the taint output file to read.

    A directory argument resolves to the `taint-output.json` inside it; any
    other path (existing file or not) is returned unchanged.
    """
    candidate = Path(taint_output_filename)
    if not candidate.is_dir():
        return candidate
    return candidate / "taint-output.json"
def index(taint_output_filename: str = "taint-output.json") -> None:
    """Index all available models in the given taint output file or directory."""
    # Records the byte span of every "model" and "issue" line so that later
    # lookups can seek straight to one JSON line instead of re-parsing the
    # file. The opened handle is kept in module state for those reads.
    global __handle, __model_index, __issue_index
    print(f"Indexing `{taint_output_filename}`")
    taint_output_path = _resolve_taint_output_path(taint_output_filename)
    handle = open(taint_output_path, "rb")
    # Build into locals and publish only once the whole file has been scanned,
    # so a failed (re-)index never leaves half-filled module state behind.
    model_index: Dict[str, FilePosition] = {}
    issue_index: Dict[str, List[FilePosition]] = collections.defaultdict(list)
    count_models = 0
    count_issues = 0
    try:
        for line, offset in _iter_with_offset(handle):
            if not line.strip():
                # Tolerate blank lines instead of failing in json.loads.
                continue
            message = json.loads(line)
            if "kind" not in message:
                continue
            position = FilePosition(offset=offset, length=len(line))
            if message["kind"] == "model":
                callable_name = message["data"]["callable"]
                # Explicit check (not `assert`) so it also runs under `python -O`.
                if callable_name in model_index:
                    raise AssertionError(
                        f"duplicate model for callable `{callable_name}`"
                    )
                model_index[callable_name] = position
                count_models += 1
            elif message["kind"] == "issue":
                issue_index[message["data"]["callable"]].append(position)
                count_issues += 1
    except BaseException:
        # Don't leak the new handle when indexing fails part-way.
        handle.close()
        raise
    previous_handle = __handle
    __handle = handle
    __model_index = model_index
    __issue_index = issue_index
    if previous_handle is not None:
        # Re-indexing replaces the old handle; close it rather than leak it.
        previous_handle.close()
    print(f"Indexed {count_models} models and {count_issues} issues")
def _assert_loaded() -> io.BufferedReader:
    """Return the open taint output handle, or fail if nothing is indexed.

    Raises:
        AssertionError: if index() has not opened a file, or if the model
            index is empty.
    """
    handle = __handle
    if handle is not None and __model_index:
        return handle
    raise AssertionError("call index() first")
def callables_containing(string: str) -> List[str]:
    """Find all callables containing the given string."""
    # Plain substring test over every indexed callable; result is sorted.
    _assert_loaded()
    matches = [name for name in __model_index if string in name]
    matches.sort()
    return matches
def callables_matching(pattern: str) -> List[str]:
    """Find all callables matching the given regular expression."""
    _assert_loaded()
    compiled = re.compile(pattern)
    # `search`, not `match`: the pattern may match anywhere in the name.
    return sorted(name for name in __model_index if compiled.search(name))
def _read(position: FilePosition) -> bytes:
    """Return the raw bytes of one indexed line from the open taint output file."""
    reader = _assert_loaded()
    offset, length = position
    reader.seek(offset)
    return reader.read(length)
def _filter_taint_tree(
    taint_tree: List[Dict[str, Any]],
    frame_predicate: Callable[[str, Dict[str, Any]], bool],
) -> List[Dict[str, Any]]:
    """Return a pruned copy of `taint_tree` keeping only accepted frames.

    `frame_predicate(caller_port, frame)` is called for every frame in every
    local taint. Local taints left with no frames are dropped, and so are
    port entries left with no local taints. Entries are shallow-copied, so
    the input tree is not modified.
    """
    filtered_tree = []
    for port_taint in taint_tree:
        port = port_taint["port"]
        kept_local_taints = []
        for local_taint in port_taint["taint"]:
            kept_frames = [
                frame for frame in local_taint["kinds"] if frame_predicate(port, frame)
            ]
            if not kept_frames:
                continue
            kept_local_taints.append({**local_taint, "kinds": kept_frames})
        if kept_local_taints:
            filtered_tree.append({**port_taint, "taint": kept_local_taints})
    return filtered_tree
def filter_model(
    model: Dict[str, Any], frame_predicate: Callable[[str, Dict[str, Any]], bool]
) -> Dict[str, Any]:
    """Return a copy of `model` whose taint trees keep only accepted frames.

    The `sources`, `sinks` and `tito` trees are each pruned with
    `_filter_taint_tree`; a section missing from `model` comes back as an
    empty list.
    """
    filtered = model.copy()
    for section in ("sources", "sinks", "tito"):
        filtered[section] = _filter_taint_tree(model.get(section, []), frame_predicate)
    return filtered
def filter_model_caller_port(model: Dict[str, Any], port: str) -> Dict[str, Any]:
    """Keep only the taint recorded on the given caller port (e.g. `result`)."""
    return filter_model(model, lambda caller_port, _frame: port == caller_port)
def filter_model_kind(model: Dict[str, Any], kind: str) -> Dict[str, Any]:
    """Keep only frames whose `kind` equals the given taint kind."""
    return filter_model(model, lambda _caller_port, frame: frame["kind"] == kind)
def _map_taint_tree(
    taint_tree: List[Dict[str, Any]],
    frame_map: Callable[[str, Dict[str, Any]], None],
    local_taint_map: Callable[[str, Dict[str, Any]], None],
) -> List[Dict[str, Any]]:
    """Return a copy of `taint_tree` after applying in-place mappers to copies.

    Every frame and every local taint is shallow-copied before its mapper sees
    it, so the mappers may mutate their argument freely without touching the
    input tree. Mapper return values are ignored.
    """

    def map_frame(port: str, frame: Dict[str, Any]) -> Dict[str, Any]:
        mapped = frame.copy()
        frame_map(port, mapped)
        return mapped

    def map_local_taint(port: str, local_taint: Dict[str, Any]) -> Dict[str, Any]:
        # Frames are mapped first; the local taint mapper then sees the new frames.
        frames = [map_frame(port, frame) for frame in local_taint["kinds"]]
        mapped = local_taint.copy()
        mapped["kinds"] = frames
        local_taint_map(port, mapped)
        return mapped

    mapped_tree = []
    for port_taint in taint_tree:
        port = port_taint["port"]
        local_taints = [map_local_taint(port, local) for local in port_taint["taint"]]
        mapped_port_taint = port_taint.copy()
        mapped_port_taint["taint"] = local_taints
        mapped_tree.append(mapped_port_taint)
    return mapped_tree
def map_model(
    model: Dict[str, Any],
    frame_map: Optional[Callable[[str, Dict[str, Any]], None]] = None,
    local_taint_map: Optional[Callable[[str, Dict[str, Any]], None]] = None,
) -> Dict[str, Any]:
    """Return a copy of `model` with its taint trees rebuilt by `_map_taint_tree`.

    `frame_map` and `local_taint_map` receive (caller port, copy) and their
    return values are ignored, so they act by mutating the copy they get.
    Either may be omitted to leave that level unchanged. The `sources`,
    `sinks` and `tito` sections are processed; a missing one comes back as
    an empty list.
    """

    def leave_unchanged(_caller_port: str, _value: Dict[str, Any]) -> None:
        return None

    apply_frame = frame_map if frame_map is not None else leave_unchanged
    apply_local = local_taint_map if local_taint_map is not None else leave_unchanged
    mapped = model.copy()
    for section in ("sources", "sinks", "tito"):
        mapped[section] = _map_taint_tree(
            model.get(section, []), apply_frame, apply_local
        )
    return mapped
def model_remove_tito_positions(model: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `model` with the `tito` key removed from every local taint."""

    def drop_tito(_caller_port: str, local_taint: Dict[str, Any]) -> None:
        local_taint.pop("tito", None)

    return map_model(model, local_taint_map=drop_tito)
def model_remove_features(model: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `model` without frame `features` or `local_features`."""

    def strip_frame(_caller_port: str, frame: Dict[str, Any]) -> None:
        frame.pop("features", None)

    def strip_local_taint(_caller_port: str, local_taint: Dict[str, Any]) -> None:
        local_taint.pop("local_features", None)

    return map_model(model, frame_map=strip_frame, local_taint_map=strip_local_taint)
def model_remove_leaf_names(model: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `model` with the `leaves` key removed from every frame."""

    def drop_leaves(_caller_port: str, frame: Dict[str, Any]) -> None:
        frame.pop("leaves", None)

    return map_model(model, frame_map=drop_leaves)
def get_model(
    callable: str,
    *,
    kind: Optional[str] = None,
    caller_port: Optional[str] = None,
    remove_sources: bool = False,
    remove_sinks: bool = False,
    remove_tito: bool = False,
    remove_tito_positions: bool = False,
    remove_features: bool = False,
    remove_leaf_names: bool = False,
) -> Dict[str, Any]:
    """Get the model for the given callable."""
    # Options:
    #   kind / caller_port       keep only frames of that kind / on that port.
    #   remove_sources/sinks/tito  drop that whole section from the result.
    #   remove_tito_positions    drop `tito` from every local taint.
    #   remove_features          drop `features` and `local_features`.
    #   remove_leaf_names        drop `leaves` from every frame.
    # Raises AssertionError if nothing is indexed or the callable has no model.
    _assert_loaded()
    if callable not in __model_index:
        raise AssertionError(f"no model for callable `{callable}`.")
    message = json.loads(_read(__model_index[callable]))
    # Explicit check (not `assert`) so it also runs under `python -O`.
    if message["kind"] != "model":
        raise AssertionError(f"expected a model for callable `{callable}`.")
    model = message["data"]
    if kind is not None:
        model = filter_model_kind(model, kind)
    if caller_port is not None:
        model = filter_model_caller_port(model, caller_port)
    if remove_tito_positions:
        model = model_remove_tito_positions(model)
    if remove_features:
        model = model_remove_features(model)
    if remove_leaf_names:
        model = model_remove_leaf_names(model)
    # Drop whole sections last. The filter/map helpers above rebuild the
    # `sources`, `sinks` and `tito` sections (as empty lists when absent), so
    # removing them earlier would let them reappear as `[]`.
    for section, remove in (
        ("sources", remove_sources),
        ("sinks", remove_sinks),
        ("tito", remove_tito),
    ):
        if remove:
            model.pop(section, None)
    return model
def print_json(data: object) -> None:
    """Pretty print json objects with syntax highlighting."""
    # Strings are parsed as JSON first. Output goes through `jq -C` when it
    # is available; otherwise (or if jq fails) plain indented JSON is printed.
    global __warned_missing_jq
    import sys

    if isinstance(data, str):
        data = json.loads(data)
    # jq writes straight to the inherited stdout file descriptor, so flush
    # anything Python has buffered first to keep the output in order.
    sys.stdout.flush()
    try:
        # Explicit "." filter: some jq versions require a filter argument.
        subprocess.run(["jq", "-C", "."], input=json.dumps(data).encode(), check=True)
        return
    except FileNotFoundError:
        jq_missing = True
    except subprocess.CalledProcessError:
        # jq is installed but rejected the input (its error went to stderr);
        # fall back to plain output instead of surfacing a traceback.
        jq_missing = False
    print(json.dumps(data, indent=" "))
    if jq_missing and not __warned_missing_jq:
        print(
            "[HINT] Install `jq` to use syntax highlighting, https://stedolan.github.io/jq/"
        )
        __warned_missing_jq = True
def print_model(
    callable: str,
    *,
    kind: Optional[str] = None,
    caller_port: Optional[str] = None,
    remove_sources: bool = False,
    remove_sinks: bool = False,
    remove_tito: bool = False,
    remove_tito_positions: bool = False,
    remove_features: bool = False,
    remove_leaf_names: bool = False,
) -> None:
    """
    Pretty print the model for the given callable.
    Optional parameters:
      kind='UserControlled'        Filter by taint kind.
      caller_port='result'         Filter by caller port.
      remove_sources=False         Drop all sources.
      remove_sinks=False           Drop all sinks.
      remove_tito=False            Drop all tito.
      remove_tito_positions=False  Drop `tito` from each local taint.
      remove_features=False        Drop `features` and `local_features`.
      remove_leaf_names=False      Drop `leaves` from each frame.
    """
    # The docstring above is printed by print_help(); its defaults must match
    # the signature (they previously claimed several options default to True).
    print_json(
        get_model(
            callable,
            kind=kind,
            caller_port=caller_port,
            remove_sources=remove_sources,
            remove_sinks=remove_sinks,
            remove_tito=remove_tito,
            remove_tito_positions=remove_tito_positions,
            remove_features=remove_features,
            remove_leaf_names=remove_leaf_names,
        )
    )
def get_issues(callable: str) -> List[Dict[str, Any]]:
    """Get all issues within the given callable."""
    # Returns [] for a callable with no recorded issues.
    _assert_loaded()
    issues = []
    # `.get` rather than `[...]`: the index is a defaultdict, and subscripting
    # it with an unknown callable would insert an empty entry as a side effect.
    for position in __issue_index.get(callable, []):
        message = json.loads(_read(position))
        # Explicit check (not `assert`) so it also runs under `python -O`.
        if message["kind"] != "issue":
            raise AssertionError(
                f"expected an issue at byte offset {position.offset}"
            )
        issues.append(message["data"])
    return issues
def print_issues(callable: str) -> None:
    """Pretty print the issues within the given callable."""
    issues = get_issues(callable)
    print_json(issues)
def print_help() -> None:
    # Print a banner, then one line per public command: an example call padded
    # to a fixed-width column, followed by that command's docstring.
    print("# Pysa Model Explorer")
    print("Available commands:")
    entries = (
        (index, "index('taint-output.json')"),
        (callables_containing, "callables_containing('foo.bar')"),
        (callables_matching, "callables_matching(r'foo\\..*')"),
        (get_model, "get_model('foo.bar')"),
        (print_model, "print_model('foo.bar')"),
        (get_issues, "get_issues('foo.bar')"),
        (print_issues, "print_issues('foo.bar')"),
        (print_json, "print_json({'a': 'b'})"),
    )
    column_width = max(len(example) for _function, example in entries)
    # Continuation lines of multi-line docstrings are pushed past the example
    # column so they read as part of the description.
    continuation = " " * (column_width + 3)
    for function, example in entries:
        description = textwrap.indent(
            textwrap.dedent(function.__doc__ or ""), prefix=continuation
        ).strip()
        print(f" {example:<{column_width}} {description}")
# Running the file directly only prints the command overview.
# NOTE(review): the helpers are presumably meant to be used interactively
# afterwards (e.g. `python -i`); confirm the intended workflow.
if __name__ == "__main__":
    print_help()