client/commands/profile.py (343 lines of code) (raw):

# Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import dataclasses import json import logging import subprocess from collections import defaultdict from pathlib import Path from typing import Dict, List, Tuple, Any, Optional, Sequence from typing_extensions import Final from .. import command_arguments, configuration as configuration_module from . import commands, remote_logging, backend_arguments LOG: logging.Logger = logging.getLogger(__name__) PHASE_NAME: str = "phase_name" TRIGGERED_DEPENDENCIES: str = "number_of_triggered_dependencies" @dataclasses.dataclass(frozen=True) class EventMetadata: name: str worker_id: int pid: int timestamp: int tags: Dict[str, str] @dataclasses.dataclass(frozen=True) # pyre-fixme[13]: Attribute `metadata` is never initialized. class Event: metadata: EventMetadata def __init__(self, metadata: EventMetadata) -> None: raise NotImplementedError @dataclasses.dataclass(frozen=True) class DurationEvent(Event): duration: int def add_phase_duration_to_result(self, result: Dict[str, int]) -> None: tags = self.metadata.tags if PHASE_NAME in tags: phase_name = tags[PHASE_NAME] result[phase_name] = self.duration if TRIGGERED_DEPENDENCIES in tags: result[phase_name + ": triggered dependencies"] = int( tags[TRIGGERED_DEPENDENCIES] ) @dataclasses.dataclass(frozen=True) class CounterEvent(Event): description: Final[Optional[str]] def _parse_tags(input: List[List[str]]) -> Dict[str, str]: return {key: value for [key, value] in input} def _parse_metadata(input_json: Dict[str, Any]) -> EventMetadata: pid = input_json["pid"] return EventMetadata( name=input_json["name"], worker_id=input_json.get("worker_id", pid), pid=pid, timestamp=input_json["timestamp"], tags=_parse_tags(input_json.get("tags", [])), ) def parse_event(input_string: str) -> Event: input_json: Dict[str, Any] = json.loads(input_string) event_type = input_json["event_type"] metadata = _parse_metadata(input_json) if event_type[0] == "Duration": duration = event_type[1] return DurationEvent(duration=duration, metadata=metadata) elif event_type[0] == "Counter": description = None if len(event_type) <= 1 else event_type[1] return CounterEvent(description=description, metadata=metadata) else: raise ValueError(f"Unrecognized event type: {input}") def parse_events(input_string: str) -> List[Event]: output: List[Event] = [] for index, line in enumerate(input_string.splitlines()): try: line = line.strip() if len(line) == 0: continue output.append(parse_event(line)) except Exception: raise RuntimeError(f"Malformed log entry detected on line {index + 1}") return output class StatisticsOverTime: _data: List[Tuple[str, int]] = [] def add(self, line: str) -> None: dividers = [ " MEMORY Shared memory size (size: ", " MEMORY Shared memory size post-typecheck (size: ", ] for divider in dividers: if divider in line: time, size_component = line.split(divider) size_in_megabytes = int(size_component[:-2]) size_in_bytes = size_in_megabytes * (10 ** 6) self._data.append((time, size_in_bytes)) def graph_total_shared_memory_size_over_time(self) -> None: try: gnuplot = subprocess.Popen(["gnuplot"], stdin=subprocess.PIPE) # pyre-fixme[16]: `Optional` has no attribute `write`. gnuplot.stdin.write(b"set term dumb 140 25\n") gnuplot.stdin.write(b"plot '-' using 1:2 title '' with linespoints \n") for (i, (_time, size)) in enumerate(self._data): # This is graphing size against # of updates, not time gnuplot.stdin.write(b"%f %f\n" % (i, size)) gnuplot.stdin.write(b"e\n") # pyre-fixme[16]: `Optional` has no attribute `flush`. gnuplot.stdin.flush() except FileNotFoundError: LOG.error("gnuplot is not installed") def to_json(self) -> str: return json.dumps(self._data) class TableStatistics: # category -> aggregation -> table name -> value # pyre-ignore: T62493941 _data: Dict[str, Dict[str, Dict[str, str]]] = defaultdict(lambda: defaultdict(dict)) _shared_heap_category: Final = "bytes serialized into shared heap" @staticmethod def sort_by_value(items: List[Tuple[str, str]]) -> None: def parse(number: str) -> float: if number[-1] == "G": return float(number[:-1]) * (10 ** 9) if number[-1] == "M": return float(number[:-1]) * (10 ** 6) if number[-1] == "K": return float(number[:-1]) * (10 ** 3) return float(number) items.sort(key=lambda x: parse(x[1]), reverse=True) def add(self, line: str) -> None: divider = "stats -- " if divider in line: header, data = line.split(divider) cells = data[:-2].split(", ") collected = [cell.split(": ") for cell in cells] tag_and_category = header[:-2].split(" (") if len(tag_and_category) == 2: tag, category = tag_and_category elif header[:3] == "ALL": tag = "ALL" category = header[4:-1] elif header[:4] == "(ALL": tag = "ALL" category = header[5:-2] else: return if len(tag) > 0: for key, value in collected: self._data[category][key][tag] = value def is_empty(self) -> bool: return len(self._data) == 0 def get_totals(self) -> List[Tuple[str, str]]: totals = list(self._data[self._shared_heap_category]["total"].items()) TableStatistics.sort_by_value(totals) return totals def get_counts(self) -> List[Tuple[str, str]]: counts = list(self._data[self._shared_heap_category]["samples"].items()) TableStatistics.sort_by_value(counts) return counts def _get_server_log(log_directory: Path) -> Path: server_stderr_path = log_directory / "new_server" / "server.stderr" if not server_stderr_path.is_file(): raise RuntimeError(f"Cannot find server output at `{server_stderr_path}`.") return server_stderr_path def _collect_memory_statistics_over_time(log_directory: Path) -> StatisticsOverTime: server_log = _get_server_log(log_directory) extracted = StatisticsOverTime() with open(server_log) as server_log_file: for line in server_log_file.readlines(): extracted.add(line) return extracted def _read_profiling_events(log_directory: Path) -> List[Event]: profiling_output = backend_arguments.get_profiling_log_path(log_directory) if not profiling_output.is_file(): raise RuntimeError( f"Cannot find profiling output at `{profiling_output}`. " + "Please run Pyre with `--enable-profiling` or " + "`--enable-memory-profiling` option first." ) return parse_events(profiling_output.read_text()) def to_traceevents(events: Sequence[Event]) -> List[Dict[str, Any]]: def to_traceevent(event: Event) -> Optional[Dict[str, Any]]: if isinstance(event, DurationEvent): duration_us = event.duration start_time_us = event.metadata.timestamp - duration_us return { "pid": event.metadata.worker_id, "tid": event.metadata.pid, "ts": start_time_us, "ph": "X", "name": event.metadata.name, "dur": duration_us, "args": event.metadata.tags, } elif isinstance(event, CounterEvent): timestamp_us = event.metadata.timestamp arguments: Dict[str, Any] = { key: int(value) for key, value in event.metadata.tags.items() } return { "pid": event.metadata.worker_id, "tid": event.metadata.pid, "ts": timestamp_us, "ph": "C", "name": event.metadata.name, "args": arguments, } else: return None return [ trace_event for trace_event in map(to_traceevent, events) if trace_event is not None ] def split_pre_and_post_initialization( events: Sequence[Event], ) -> Tuple[Sequence[Event], Sequence[Event]]: initialization_point = next( ( index + 1 for index, event in enumerate(events) if event.metadata.name == "initialization" ), len(events), ) return events[:initialization_point], events[initialization_point:] def to_cold_start_phases(events: Sequence[Event]) -> Dict[str, int]: result: Dict[str, int] = {} pre_initialization_events, _ = split_pre_and_post_initialization(events) for event in pre_initialization_events: if not isinstance(event, DurationEvent): continue event.add_phase_duration_to_result(result) if event.metadata.name == "initialization": result["total"] = event.duration return result def to_incremental_updates(events: Sequence[Event]) -> List[Dict[str, int]]: results: List[Dict[str, int]] = [] current: Dict[str, int] = {} _, post_initialization_events = split_pre_and_post_initialization(events) for event in post_initialization_events: if not isinstance(event, DurationEvent): continue event.add_phase_duration_to_result(current) if event.metadata.name == "incremental check": current["total"] = event.duration results.append(current) current = {} return results def to_taint(events: Sequence[Event]) -> Dict[str, int]: result: Dict[str, int] = {} for event in events: if not isinstance(event, DurationEvent): continue event.add_phase_duration_to_result(result) fixpoint_events = [ event for event in events if isinstance(event, DurationEvent) and event.metadata.tags.get(PHASE_NAME) == "Static analysis fixpoint" ] if len(fixpoint_events) == 0: return result for name, value in fixpoint_events[-1].metadata.tags.items(): if name == PHASE_NAME: continue result[name.capitalize()] = int(value) return result def print_individual_table_sizes( configuration: configuration_module.Configuration, ) -> None: server_log = _get_server_log(Path(configuration.log_directory)) extracted = TableStatistics() with open(str(server_log)) as server_log_file: for line in server_log_file.readlines(): extracted.add(line) if extracted.is_empty(): raise RuntimeError( "Cannot find table size data in " + f"`{server_log.as_posix()}`. " + "Please run Pyre with `--debug` option first." ) sizes = json.dumps(extracted.get_totals()) counts = json.dumps(extracted.get_counts()) # I manually put together this json in order to be # simultaneously machine and human readable combined = ( "{\n" + f' "total_table_sizes": {sizes},\n' + f' "table_key_counts": {counts}\n' + "}" ) print(combined) def print_total_shared_memory_size_over_time( configuration: configuration_module.Configuration, ) -> None: memory_over_time = _collect_memory_statistics_over_time( Path(configuration.log_directory) ).to_json() print(memory_over_time) def print_total_shared_memory_size_over_time_graph( configuration: configuration_module.Configuration, ) -> None: statistics = _collect_memory_statistics_over_time(Path(configuration.log_directory)) statistics.graph_total_shared_memory_size_over_time() def print_trace_event( configuration: configuration_module.Configuration, ) -> None: events = _read_profiling_events(Path(configuration.log_directory)) print(json.dumps(to_traceevents(events))) def print_cold_start_phases( configuration: configuration_module.Configuration, ) -> None: events = _read_profiling_events(Path(configuration.log_directory)) print(json.dumps(to_cold_start_phases(events), indent=2)) def print_incremental_updates( configuration: configuration_module.Configuration, ) -> None: events = _read_profiling_events(Path(configuration.log_directory)) print(json.dumps(to_incremental_updates(events), indent=2)) def print_taint( configuration: configuration_module.Configuration, ) -> None: events = _read_profiling_events(Path(configuration.log_directory)) print(json.dumps(to_taint(events), indent=2)) def run_profile( configuration: configuration_module.Configuration, output: command_arguments.ProfileOutput, ) -> commands.ExitCode: if output == command_arguments.ProfileOutput.INDIVIDUAL_TABLE_SIZES: print_individual_table_sizes(configuration) elif output == command_arguments.ProfileOutput.TOTAL_SHARED_MEMORY_SIZE_OVER_TIME: print_total_shared_memory_size_over_time(configuration) elif ( output == command_arguments.ProfileOutput.TOTAL_SHARED_MEMORY_SIZE_OVER_TIME_GRAPH ): print_total_shared_memory_size_over_time_graph(configuration) elif output == command_arguments.ProfileOutput.TRACE_EVENT: print_trace_event(configuration) elif output == command_arguments.ProfileOutput.COLD_START_PHASES: print_cold_start_phases(configuration) elif output == command_arguments.ProfileOutput.INCREMENTAL_UPDATES: print_incremental_updates(configuration) elif output == command_arguments.ProfileOutput.TAINT: print_taint(configuration) else: raise RuntimeError(f"Unrecognized output format: {output}") return commands.ExitCode.SUCCESS @remote_logging.log_usage(command_name="profile") def run( configuration: configuration_module.Configuration, output: command_arguments.ProfileOutput, ) -> commands.ExitCode: try: return run_profile(configuration, output) except Exception as error: raise commands.ClientException( f"Exception occurred during profile: {error}" ) from error