sapp/trace_graph.py (422 lines of code) (raw):

# Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. # pyre-strict import json import logging from collections import defaultdict from typing import DefaultDict, Dict, Iterable, List, Optional, Set, Tuple, Any from .bulk_saver import BulkSaver from .models import ( DBID, SHARED_TEXT_LENGTH, LeafMapping, Issue, IssueInstance, IssueInstanceFixInfo, SharedText, SharedTextKind, TraceFrame, TraceFrameAnnotation, TraceKind, Feature, ) log: logging.Logger = logging.getLogger("sapp") LeafIDToDepthMap = Dict[int, Optional[int]] class TraceGraph(object): """Represents a graph of the Zoncolan trace steps. Nodes of the graph are the issues, preconditions, postconditions, sources and sinks. Edges are the the assocs, and for pre/postconditions, the map of 'caller->callee' gives a one direction edge to the next pre/postcondition, and the map of 'callee->caller' gives the reverse edge. """ def __init__(self) -> None: self._issues: Dict[int, Issue] = {} self._issue_instances: Dict[int, IssueInstance] = {} self._trace_annotations: Dict[int, TraceFrameAnnotation] = {} # Create a mapping of (caller_id, caller_port) to the corresponding # trace frame's id. self._trace_frames_map: DefaultDict[ TraceKind, DefaultDict[Tuple[int, str], Set[int]] ] = defaultdict(lambda: defaultdict(set)) # Similar to _trace_frames_map, but maps the reverse direction # of the trace graph, i.e. (callee_id, callee_port) to the # trace_frame_id. self._trace_frames_rev_map: DefaultDict[ TraceKind, DefaultDict[Tuple[int, str], Set[int]] ] = defaultdict(lambda: defaultdict(set)) self._trace_frames: Dict[int, TraceFrame] = {} self._features: Dict[int, Feature] = {} self._feature_lookup: Dict[str, int] = {} self._shared_texts: Dict[int, SharedText] = {} self._shared_text_lookup: ( DefaultDict[SharedTextKind, Dict[str, int]] ) = defaultdict(dict) self._trace_frame_leaf_assoc: DefaultDict[int, LeafIDToDepthMap] = defaultdict( lambda: {} ) self._trace_frame_issue_instance_assoc: DefaultDict[ int, Set[int] ] = defaultdict(set) self._issue_instance_trace_frame_assoc: DefaultDict[ int, Set[int] ] = defaultdict(set) self._trace_frame_annotation_trace_frame_assoc: DefaultDict[ int, Set[int] ] = defaultdict(set) self._trace_frame_trace_frame_annotation_assoc: DefaultDict[ int, Set[int] ] = defaultdict(set) self._issue_instance_shared_text_assoc: DefaultDict[ int, Set[int] ] = defaultdict(set) self._shared_text_issue_instance_assoc: DefaultDict[ int, Set[int] ] = defaultdict(set) self._issue_instance_feature_assoc: DefaultDict[int, Set[int]] = defaultdict( set ) self._feature_issue_instance_assoc: DefaultDict[int, Set[int]] = defaultdict( set ) self._issue_instance_fix_info: Dict[int, IssueInstanceFixInfo] = {} # !!!!! IMPORTANT !!!!! # IF YOU ARE ADDING MORE FIELDS/EDGES TO THIS GRAPH, CHECK IF # TrimmedTraceGraph NEEDS TO BE UPDATED AS WELL. # # TrimmedTraceGraph will populate itself from this object. It searches # TrimmedGraph for nodes and edges of 'affected_files' and copies them # over. If new fields/edges are added, these may need to be copied in # TrimmedTraceGraph as well. def add_issue(self, issue: Issue) -> None: assert issue.id.local_id not in self._issues, "Issue already exists" self._issues[issue.id.local_id] = issue def get_issue(self, issue_id: DBID) -> Issue: return self._issues[issue_id.local_id] def get_issues(self) -> Iterable[Issue]: return (issue for issue in self._issues.values()) def get_number_issues(self) -> int: return len(self._issues) def add_issue_instance(self, instance: IssueInstance) -> None: assert ( instance.id.local_id not in self._issue_instances ), "Instance already exists" self._issue_instances[instance.id.local_id] = instance def get_issue_instances(self) -> Iterable[IssueInstance]: return (instance for instance in self._issue_instances.values()) def add_issue_instance_fix_info( self, instance: IssueInstance, fix_info: IssueInstanceFixInfo ) -> None: assert ( instance.id.local_id not in self._issue_instance_fix_info ), "Instance fix info already exists" self._issue_instance_fix_info[instance.id.local_id] = fix_info def get_text(self, shared_text_id: DBID) -> str: return self._shared_texts[shared_text_id.local_id].contents def get_shared_text_by_local_id(self, shared_text_id: int) -> SharedText: return self._shared_texts[shared_text_id] def get_shared_text( self, kind: SharedTextKind, content: str ) -> Optional[SharedText]: if kind in self._shared_text_lookup: contents = self._shared_text_lookup[kind] if content in contents and contents[content] in self._shared_texts: return self._shared_texts[contents[content]] return None def has_trace_frames_with_caller( self, kind: TraceKind, caller_id: DBID, caller_port: str ) -> bool: if self._trace_frames_map[kind]: key = (caller_id.local_id, caller_port) return key in self._trace_frames_map[kind] else: return False def has_postconditions_with_caller(self, caller_id: DBID, caller_port: str) -> bool: return self.has_trace_frames_with_caller( TraceKind.postcondition, caller_id, caller_port ) def has_preconditions_with_caller(self, caller_id: DBID, caller_port: str) -> bool: return self.has_trace_frames_with_caller( TraceKind.precondition, caller_id, caller_port ) def add_trace_annotation(self, annotation: TraceFrameAnnotation) -> None: self._trace_annotations[annotation.id.local_id] = annotation def get_condition_annotations(self, cond_id: int) -> List[TraceFrameAnnotation]: return [ t for t in self._trace_annotations.values() if t.trace_frame_id.local_id == cond_id ] def get_annotation_trace_frames(self, ann_id: int) -> List[TraceFrame]: if ann_id in self._trace_frame_annotation_trace_frame_assoc: return [ self.get_trace_frame_from_id(tf_id) for tf_id in self._trace_frame_annotation_trace_frame_assoc[ann_id] ] else: return [] def add_trace_frame(self, trace_frame: TraceFrame) -> None: key = (trace_frame.caller_id.local_id, trace_frame.caller_port) rev_key = (trace_frame.callee_id.local_id, trace_frame.callee_port) # pyre-fixme[6]: Expected `TraceKind` for 1st param but got `str`. self._trace_frames_map[trace_frame.kind][key].add(trace_frame.id.local_id) # pyre-fixme[6]: Expected `TraceKind` for 1st param but got `str`. self._trace_frames_rev_map[trace_frame.kind][rev_key].add( trace_frame.id.local_id ) self._trace_frames[trace_frame.id.local_id] = trace_frame def get_trace_frames_from_caller( self, kind: TraceKind, caller_id: DBID, caller_port: str ) -> List[TraceFrame]: key = (caller_id.local_id, caller_port) return [ self._trace_frames[trace_frame_id] for trace_frame_id in self._trace_frames_map[kind][key] ] def get_trace_frame_from_id(self, id: int) -> TraceFrame: return self._trace_frames[id] def add_shared_text(self, shared_text: SharedText) -> None: assert ( shared_text.id.local_id not in self._shared_texts ), "Shared text already exists" assert ( shared_text.kind not in self._shared_text_lookup or shared_text.contents not in self._shared_text_lookup[shared_text.kind] ), "Shared text with same kind, contents exists" self._shared_texts[shared_text.id.local_id] = shared_text # Allow look up of SharedTexts by name and kind (to optimize # get_shared_text which is called when parsing each issue instance) self._shared_text_lookup[shared_text.kind][ shared_text.contents ] = shared_text.id.local_id def add_feature(self, bc: Feature) -> None: self._features[bc.id.local_id] = bc self._feature_lookup[json.dumps(bc.data, sort_keys=True)] = bc.id.local_id def get_feature(self, feature: Dict[str, Any]) -> Optional[Feature]: feature_serialized = json.dumps(feature, sort_keys=True) if feature_serialized in self._feature_lookup: if self._feature_lookup[feature_serialized] in self._features: return self._features[self._feature_lookup[feature_serialized]] return None def get_or_add_feature(self, feature: Dict[str, Any]) -> Feature: feature_object = self.get_feature(feature) if feature_object is None: feature_object = Feature.Record( id=DBID(), data=feature, ) self.add_feature(feature_object) return feature_object def get_or_add_shared_text(self, kind: SharedTextKind, name: str) -> SharedText: name = name[:SHARED_TEXT_LENGTH] shared_text = self.get_shared_text(kind, name) if shared_text is None: shared_text = SharedText.Record(id=DBID(), contents=name, kind=kind) self.add_shared_text(shared_text) return shared_text def add_trace_frame_leaf_assoc( self, trace_frame: TraceFrame, leaf: SharedText, depth: Optional[int] ) -> None: self._trace_frame_leaf_assoc[trace_frame.id.local_id][leaf.id.local_id] = depth def add_trace_frame_leaf_by_local_id_assoc( self, trace_frame: TraceFrame, leaf_id: int, depth: Optional[int] ) -> None: self._trace_frame_leaf_assoc[trace_frame.id.local_id][leaf_id] = depth def get_trace_frame_leaf_ids(self, trace_frame: TraceFrame) -> Set[int]: ids: Set[int] = { id for id in self._trace_frame_leaf_assoc[trace_frame.id.local_id] } return ids def get_trace_frame_leaf_ids_by_kind( self, trace_frame: TraceFrame, kind: SharedTextKind ) -> Set[int]: return { id for id in self._trace_frame_leaf_assoc[trace_frame.id.local_id] if self._shared_texts[id].kind == kind } def get_trace_frame_leaf_ids_with_depths( self, trace_frame: TraceFrame ) -> LeafIDToDepthMap: return self._trace_frame_leaf_assoc[trace_frame.id.local_id] def add_issue_instance_trace_frame_assoc( self, instance: IssueInstance, trace_frame: TraceFrame ) -> None: self._issue_instance_trace_frame_assoc[instance.id.local_id].add( trace_frame.id.local_id ) self._trace_frame_issue_instance_assoc[trace_frame.id.local_id].add( instance.id.local_id ) def add_trace_frame_annotation_trace_frame_assoc( self, annotation: TraceFrameAnnotation, trace_frame: TraceFrame ) -> None: self._trace_frame_annotation_trace_frame_assoc[annotation.id.local_id].add( trace_frame.id.local_id ) self._trace_frame_trace_frame_annotation_assoc[trace_frame.id.local_id].add( annotation.id.local_id ) def get_issue_instance_trace_frames( self, instance: IssueInstance ) -> List[TraceFrame]: if instance.id.local_id in self._issue_instance_trace_frame_assoc: return [ self.get_trace_frame_from_id(id) for id in self._issue_instance_trace_frame_assoc[instance.id.local_id] ] else: return [] def get_next_trace_frames(self, trace_frame: TraceFrame) -> Iterable[TraceFrame]: return self.get_trace_frames_from_caller( # pyre-fixme[6]: Expected `TraceKind` for 1st param but got `str`. trace_frame.kind, trace_frame.callee_id, trace_frame.callee_port, ) def add_issue_instance_shared_text_assoc_id( self, instance: IssueInstance, shared_text_id: int ) -> None: self._issue_instance_shared_text_assoc[instance.id.local_id].add(shared_text_id) self._shared_text_issue_instance_assoc[shared_text_id].add(instance.id.local_id) def add_issue_instance_shared_text_assoc( self, instance: IssueInstance, shared_text: SharedText ) -> None: self.add_issue_instance_shared_text_assoc_id(instance, shared_text.id.local_id) def add_issue_instance_feature_assoc_id( self, instance: IssueInstance, feature_id: int ) -> None: self._issue_instance_feature_assoc[instance.id.local_id].add(feature_id) self._feature_issue_instance_assoc[feature_id].add(instance.id.local_id) def add_issue_instance_feature_assoc( self, instance: IssueInstance, feature: Feature ) -> None: self.add_issue_instance_feature_assoc_id(instance, feature.id.local_id) def get_issue_instance_shared_texts( self, instance_id: int, kind: SharedTextKind ) -> List[SharedText]: return [ self._shared_texts[msg_id] for msg_id in self._issue_instance_shared_text_assoc[instance_id] if self._shared_texts[msg_id].kind == kind ] def update_bulk_saver(self, bulk_saver: BulkSaver) -> None: bulk_saver.add_all(list(self._issues.values())) bulk_saver.add_all(list(self._issue_instances.values())) bulk_saver.add_all(list(self._trace_frames.values())) bulk_saver.add_all(list(self._issue_instance_fix_info.values())) bulk_saver.add_all(list(self._trace_annotations.values())) bulk_saver.add_all(list(self._shared_texts.values())) bulk_saver.add_all(list(self._features.values())) self._save_issue_instance_trace_frame_assoc(bulk_saver) self._save_trace_frame_leaf_assoc(bulk_saver) self._save_issue_instance_shared_text_assoc(bulk_saver) self._save_trace_frame_annotation_trace_frame_assoc(bulk_saver) self._save_issue_instance_feature_assoc(bulk_saver) def _save_issue_instance_trace_frame_assoc(self, bulk_saver: BulkSaver) -> None: for ( trace_frame_id, instance_ids, ) in self._trace_frame_issue_instance_assoc.items(): for instance_id in instance_ids: bulk_saver.add_issue_instance_trace_frame_assoc( self._issue_instances[instance_id], self._trace_frames[trace_frame_id], ) def _save_issue_instance_feature_assoc(self, bulk_saver: BulkSaver) -> None: for ( feature_id, instance_ids, ) in self._feature_issue_instance_assoc.items(): for instance_id in instance_ids: bulk_saver.add_issue_instance_feature_assoc( self._issue_instances[instance_id], self._features[feature_id], ) def _save_trace_frame_annotation_trace_frame_assoc( self, bulk_saver: BulkSaver ) -> None: for ( trace_annotation_id, trace_frame_ids, ) in self._trace_frame_annotation_trace_frame_assoc.items(): for trace_frame_id in trace_frame_ids: bulk_saver.add_trace_frame_annotation_trace_frame_assoc( self._trace_annotations[trace_annotation_id], self._trace_frames[trace_frame_id], ) def _save_trace_frame_leaf_assoc(self, bulk_saver: BulkSaver) -> None: """Adds trace frame leaf assocs to bulk saver after filtering them: 1. if frame is a leaf, include all kinds 2. otherwise, find outgoing leaf kinds and intersect with union of incoming leaf kinds of all successor frames. 3. include only kinds that map to one of these outgoing kinds. """ for trace_frame_id, leaf_ids in self._trace_frame_leaf_assoc.items(): frame = self._trace_frames[trace_frame_id] valid_frame_leaf_ids = self._compute_valid_frame_leaves(frame) for leaf_id, depth in leaf_ids.items(): leaf_text = self._shared_texts[leaf_id] if ( leaf_text.kind is SharedTextKind.FEATURE or leaf_id in valid_frame_leaf_ids or self._is_opposite_leaf(frame, leaf_text) ): bulk_saver.add_trace_frame_leaf_assoc(leaf_text, frame, depth) else: # Logging all the leaf kinds that are omitted causes large logs. pass def _is_opposite_leaf(self, frame: TraceFrame, leaf: SharedText) -> bool: """We may be propagating sources along sink traces or vice versa. These should not be filtered and are identified here.""" return ( frame.kind == TraceKind.PRECONDITION and leaf.kind == SharedTextKind.SOURCE ) or ( frame.kind == TraceKind.POSTCONDITION and leaf.kind == SharedTextKind.SINK ) def _compute_valid_frame_leaves(self, frame: TraceFrame) -> Set[int]: leaf_mapping = frame.leaf_mapping is_leaf_frame = self.is_leaf_port(frame.callee_port) if not is_leaf_frame: callee_frames = self.get_next_trace_frames(frame) callee_leaf_ids = { callee_map.caller_leaf for callee_frame in callee_frames for callee_map in callee_frame.leaf_mapping } else: callee_leaf_ids = set() return { leaf_map.transform for leaf_map in leaf_mapping if is_leaf_frame or (leaf_map.callee_leaf in callee_leaf_ids) } def is_leaf_port(self, port: str) -> bool: return ( port == "leaf" or port == "source" or port == "sink" or port.startswith("anchor:") or port.startswith("producer:") ) def _save_issue_instance_shared_text_assoc(self, bulk_saver: BulkSaver) -> None: for ( shared_text_id, instance_ids, ) in self._shared_text_issue_instance_assoc.items(): for instance_id in instance_ids: bulk_saver.add_issue_instance_shared_text_assoc( self._issue_instances[instance_id], self._shared_texts[shared_text_id], ) def compute_next_leaf_kinds( self, leaves: Set[int], leaf_mapping: Set[LeafMapping] ) -> Set[int]: """Normally, we would just intersect leaves and frame leaves, but since frame leaves can indicate local transforms of the form T1:...:Tn@G1..Gm:S, we need to be more careful. We first need to identify which frame leaves match by substituting the @ for :, in general that would be T1:..Tn:G1..Gm:S. Then given these matches, erase everything up to and including the @ sign. That will be the new leaf kind. In general, that is G1..Gm:S. leaf_mapping is already normalized to (caller_leaf_id, callee_leaf_id), i.e. which callers map to which caller ids obtained by performing the substitutions. For non-transform kinds, the leaf mapping contains identical caller_leaf_id, callee_leaf_id. """ next_kinds = set() for leaf_map in leaf_mapping: if leaf_map.caller_leaf in leaves: next_kinds.add(leaf_map.callee_leaf) return next_kinds def compute_prev_leaf_kinds( self, leaves: Set[int], leaf_mapping: Set[LeafMapping] ) -> Set[int]: """Same as next_leaf_kinds but when following from leaves to issues.""" next_kinds = set() for leaf_map in leaf_mapping: if leaf_map.callee_leaf in leaves: next_kinds.add(leaf_map.caller_leaf) return next_kinds def get_transform_normalized_leaf(self, leaf: str) -> str: return leaf.replace("@", ":", 1) def get_transform_normalized_kind_id(self, leaf_kind: SharedText) -> int: assert ( leaf_kind.kind == SharedTextKind.SINK or leaf_kind.kind == SharedTextKind.SOURCE ) if "@" in leaf_kind.contents: normal_name = self.get_transform_normalized_leaf(leaf_kind.contents) normal_kind = self.get_or_add_shared_text(leaf_kind.kind, normal_name) return normal_kind.id.local_id else: return leaf_kind.id.local_id def get_transformed_kind_id(self, leaf_kind: SharedText) -> int: assert ( leaf_kind.kind == SharedTextKind.SINK or leaf_kind.kind == SharedTextKind.SOURCE ) if "@" in leaf_kind.contents: splits = leaf_kind.contents.split("@", 1) remaining_kind = self.get_or_add_shared_text(leaf_kind.kind, splits[1]) return remaining_kind.id.local_id else: return leaf_kind.id.local_id def get_callee_leaf_kinds_of_frame(self, trace_frame: TraceFrame) -> Set[int]: """Get leaf kinds expected by frames starting with the callee. In case of transforms, this is the untransformed kinds. """ leaf_mapping = trace_frame.leaf_mapping assert leaf_mapping is not None return {leaf_map.callee_leaf for leaf_map in leaf_mapping} def get_caller_leaf_kinds_of_frame(self, trace_frame: TraceFrame) -> Set[int]: """Get leaf kinds expected by frames starting with the caller. In case of transforms, this is the transformed kinds. """ leaf_mapping = trace_frame.leaf_mapping assert leaf_mapping is not None return {leaf_map.caller_leaf for leaf_map in leaf_mapping}