# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import dataclasses
import json
import logging
import subprocess
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple

from typing_extensions import Final

from .. import command_arguments, configuration as configuration_module
from . import backend_arguments, commands, remote_logging

LOG: logging.Logger = logging.getLogger(__name__)

PHASE_NAME: str = "phase_name"
TRIGGERED_DEPENDENCIES: str = "number_of_triggered_dependencies"


@dataclasses.dataclass(frozen=True)
class EventMetadata:
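    """Metadata attached to every profiling event."""
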
    name: str
    worker_id: int
    pid: int
    timestamp: int
    tags: Dict[str, str]


@dataclasses.dataclass(frozen=True)
# pyre-fixme[13]: Attribute `metadata` is never initialized.
class Event:
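    """Abstract base for profiling events; construct one of the concrete
    subclasses (DurationEvent, CounterEvent) instead."""
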
    metadata: EventMetadata

    def __init__(self, metadata: EventMetadata) -> None:
        raise NotImplementedError


@dataclasses.dataclass(frozen=True)
class DurationEvent(Event):
    duration: int

    def add_phase_duration_to_result(self, result: Dict[str, int]) -> None:
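        """Record this event's duration in `result` under its phase name tag,
        along with any triggered-dependency count."""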
        tags = self.metadata.tags
        if PHASE_NAME in tags:
            phase_name = tags[PHASE_NAME]
            result[phase_name] = self.duration
            if TRIGGERED_DEPENDENCIES in tags:
                result[phase_name + ": triggered dependencies"] = int(
                    tags[TRIGGERED_DEPENDENCIES]
                )


@dataclasses.dataclass(frozen=True)
class CounterEvent(Event):
    description: Final[Optional[str]]


def _parse_tags(input: List[List[str]]) -> Dict[str, str]:
    return {key: value for [key, value] in input}


def _parse_metadata(input_json: Dict[str, Any]) -> EventMetadata:
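    """Build an EventMetadata from a decoded log entry, defaulting
    `worker_id` to `pid` when the entry carries no worker id."""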
    pid = input_json["pid"]
    return EventMetadata(
        name=input_json["name"],
        worker_id=input_json.get("worker_id", pid),
        pid=pid,
        timestamp=input_json["timestamp"],
        tags=_parse_tags(input_json.get("tags", [])),
    )


def parse_event(input_string: str) -> Event:
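    """Parse a single line of backend profiling output.

    Each line is expected to be a JSON object; the shape sketched below is
    inferred from the fields this parser reads:

        {"name": "...", "pid": 123, "timestamp": 456,
         "event_type": ["Duration", 42],
         "tags": [["phase_name", "parse"]]}
    """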
    input_json: Dict[str, Any] = json.loads(input_string)
    event_type = input_json["event_type"]
    metadata = _parse_metadata(input_json)
    if event_type[0] == "Duration":
        duration = event_type[1]
        return DurationEvent(duration=duration, metadata=metadata)
    elif event_type[0] == "Counter":
        description = None if len(event_type) <= 1 else event_type[1]
        return CounterEvent(description=description, metadata=metadata)
    else:
        raise ValueError(f"Unrecognized event type: {event_type}")


def parse_events(input_string: str) -> List[Event]:
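    """Parse one event per non-empty line, reporting the 1-based line number
    of the first malformed entry."""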
    output: List[Event] = []
    for index, line in enumerate(input_string.splitlines()):
        line = line.strip()
        if len(line) == 0:
            continue
        try:
            output.append(parse_event(line))
        except Exception as error:
            raise RuntimeError(
                f"Malformed log entry detected on line {index + 1}"
            ) from error
    return output


class StatisticsOverTime:
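    """Collects shared memory size samples scraped from server log lines."""
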
    def __init__(self) -> None:
        # An instance attribute rather than a class-level mutable default,
        # so that separate instances do not share data.
        self._data: List[Tuple[str, int]] = []

    def add(self, line: str) -> None:
        dividers = [
            " MEMORY Shared memory size (size: ",
            " MEMORY Shared memory size post-typecheck (size: ",
        ]
        for divider in dividers:
            if divider in line:
                time, size_component = line.split(divider)
                # `size_component` holds the megabyte count plus two trailing
                # characters (the closing parenthesis and newline).
                size_in_megabytes = int(size_component[:-2])
                size_in_bytes = size_in_megabytes * (10 ** 6)
                self._data.append((time, size_in_bytes))

    def graph_total_shared_memory_size_over_time(self) -> None:
        try:
            gnuplot = subprocess.Popen(["gnuplot"], stdin=subprocess.PIPE)
            # pyre-fixme[16]: `Optional` has no attribute `write`.
            gnuplot.stdin.write(b"set term dumb 140 25\n")
            gnuplot.stdin.write(b"plot '-' using 1:2 title '' with linespoints \n")
            for i, (_time, size) in enumerate(self._data):
                # This graphs size against the number of updates, not time.
                gnuplot.stdin.write(b"%f %f\n" % (i, size))
            gnuplot.stdin.write(b"e\n")
            # pyre-fixme[16]: `Optional` has no attribute `flush`.
            gnuplot.stdin.flush()
        except FileNotFoundError:
            LOG.error("gnuplot is not installed")

    def to_json(self) -> str:
        return json.dumps(self._data)


class TableStatistics:
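    """Collects per-table shared heap statistics scraped from server log
    lines."""
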
    _shared_heap_category: Final = "bytes serialized into shared heap"

    def __init__(self) -> None:
        # category -> aggregation -> table name -> value. An instance
        # attribute so that separate instances do not share data.
        # pyre-ignore: T62493941
        self._data: Dict[str, Dict[str, Dict[str, str]]] = defaultdict(
            lambda: defaultdict(dict)
        )

    @staticmethod
    def sort_by_value(items: List[Tuple[str, str]]) -> None:
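        """Sort (name, size) pairs in place, largest first; sizes may carry
        a G/M/K suffix."""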
        def parse(number: str) -> float:
            if number[-1] == "G":
                return float(number[:-1]) * (10 ** 9)
            if number[-1] == "M":
                return float(number[:-1]) * (10 ** 6)
            if number[-1] == "K":
                return float(number[:-1]) * (10 ** 3)
            return float(number)

        items.sort(key=lambda x: parse(x[1]), reverse=True)

    def add(self, line: str) -> None:
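        """Ingest one `stats -- ` server log line.

        The header is expected to look like `<table> (<category>): ` or an
        `ALL`/`(ALL` aggregate (inferred from the parsing below); the data
        portion is a comma-separated list of `key: value` cells.
        """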
divider = "stats -- "
if divider in line:
header, data = line.split(divider)
cells = data[:-2].split(", ")
collected = [cell.split(": ") for cell in cells]
tag_and_category = header[:-2].split(" (")
if len(tag_and_category) == 2:
tag, category = tag_and_category
elif header[:3] == "ALL":
tag = "ALL"
category = header[4:-1]
elif header[:4] == "(ALL":
tag = "ALL"
category = header[5:-2]
else:
return
if len(tag) > 0:
for key, value in collected:
self._data[category][key][tag] = value

    def is_empty(self) -> bool:
        return len(self._data) == 0

    def get_totals(self) -> List[Tuple[str, str]]:
        totals = list(self._data[self._shared_heap_category]["total"].items())
        TableStatistics.sort_by_value(totals)
        return totals

    def get_counts(self) -> List[Tuple[str, str]]:
        counts = list(self._data[self._shared_heap_category]["samples"].items())
        TableStatistics.sort_by_value(counts)
        return counts


def _get_server_log(log_directory: Path) -> Path:
    server_stderr_path = log_directory / "new_server" / "server.stderr"
    if not server_stderr_path.is_file():
        raise RuntimeError(f"Cannot find server output at `{server_stderr_path}`.")
    return server_stderr_path


def _collect_memory_statistics_over_time(log_directory: Path) -> StatisticsOverTime:
    server_log = _get_server_log(log_directory)
    extracted = StatisticsOverTime()
    with open(server_log) as server_log_file:
        for line in server_log_file:
            extracted.add(line)
    return extracted


def _read_profiling_events(log_directory: Path) -> List[Event]:
    profiling_output = backend_arguments.get_profiling_log_path(log_directory)
    if not profiling_output.is_file():
        raise RuntimeError(
            f"Cannot find profiling output at `{profiling_output}`. "
            "Please run Pyre with the `--enable-profiling` or "
            "`--enable-memory-profiling` option first."
        )
    return parse_events(profiling_output.read_text())


def to_traceevents(events: Sequence[Event]) -> List[Dict[str, Any]]:
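    """Convert profiling events into Chrome trace-event dicts: duration
    events become complete events (`"ph": "X"`) and counter events become
    counter events (`"ph": "C"`), with timestamps in microseconds."""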
    def to_traceevent(event: Event) -> Optional[Dict[str, Any]]:
        if isinstance(event, DurationEvent):
            duration_us = event.duration
            start_time_us = event.metadata.timestamp - duration_us
            return {
                "pid": event.metadata.worker_id,
                "tid": event.metadata.pid,
                "ts": start_time_us,
                "ph": "X",
                "name": event.metadata.name,
                "dur": duration_us,
                "args": event.metadata.tags,
            }
        elif isinstance(event, CounterEvent):
            timestamp_us = event.metadata.timestamp
            arguments: Dict[str, Any] = {
                key: int(value) for key, value in event.metadata.tags.items()
            }
            return {
                "pid": event.metadata.worker_id,
                "tid": event.metadata.pid,
                "ts": timestamp_us,
                "ph": "C",
                "name": event.metadata.name,
                "args": arguments,
            }
        else:
            return None

    return [
        trace_event
        for trace_event in map(to_traceevent, events)
        if trace_event is not None
    ]


def split_pre_and_post_initialization(
    events: Sequence[Event],
) -> Tuple[Sequence[Event], Sequence[Event]]:
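    """Split events just after the first one named "initialization"; when no
    such event exists, everything is treated as pre-initialization."""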
    initialization_point = next(
        (
            index + 1
            for index, event in enumerate(events)
            if event.metadata.name == "initialization"
        ),
        len(events),
    )
    return events[:initialization_point], events[initialization_point:]


def to_cold_start_phases(events: Sequence[Event]) -> Dict[str, int]:
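    """Collect per-phase durations from pre-initialization events, reporting
    the duration of the "initialization" event itself as "total"."""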
    result: Dict[str, int] = {}
    pre_initialization_events, _ = split_pre_and_post_initialization(events)
    for event in pre_initialization_events:
        if not isinstance(event, DurationEvent):
            continue
        event.add_phase_duration_to_result(result)
        if event.metadata.name == "initialization":
            result["total"] = event.duration
    return result


def to_incremental_updates(events: Sequence[Event]) -> List[Dict[str, int]]:
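    """Group post-initialization phase durations into one dict per
    "incremental check" event, each carrying a "total" duration."""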
    results: List[Dict[str, int]] = []
    current: Dict[str, int] = {}
    _, post_initialization_events = split_pre_and_post_initialization(events)
    for event in post_initialization_events:
        if not isinstance(event, DurationEvent):
            continue
        event.add_phase_duration_to_result(current)
        if event.metadata.name == "incremental check":
            current["total"] = event.duration
            results.append(current)
            current = {}
    return results


def to_taint(events: Sequence[Event]) -> Dict[str, int]:
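    """Collect phase durations, plus the counters tagged on the last
    "Static analysis fixpoint" event if one is present."""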
    result: Dict[str, int] = {}
    for event in events:
        if not isinstance(event, DurationEvent):
            continue
        event.add_phase_duration_to_result(result)

    fixpoint_events = [
        event
        for event in events
        if isinstance(event, DurationEvent)
        and event.metadata.tags.get(PHASE_NAME) == "Static analysis fixpoint"
    ]
    if len(fixpoint_events) == 0:
        return result
    # Surface every counter tagged on the last fixpoint event, except the
    # phase name itself.
    for name, value in fixpoint_events[-1].metadata.tags.items():
        if name == PHASE_NAME:
            continue
        result[name.capitalize()] = int(value)
    return result


def print_individual_table_sizes(
    configuration: configuration_module.Configuration,
) -> None:
    server_log = _get_server_log(Path(configuration.log_directory))
    extracted = TableStatistics()
    with open(str(server_log)) as server_log_file:
        for line in server_log_file:
            extracted.add(line)
    if extracted.is_empty():
        raise RuntimeError(
            f"Cannot find table size data in `{server_log.as_posix()}`. "
            "Please run Pyre with the `--debug` option first."
        )
    sizes = json.dumps(extracted.get_totals())
    counts = json.dumps(extracted.get_counts())
    # This JSON is assembled by hand so that the output stays both machine-
    # and human-readable.
    combined = (
        "{\n"
        + f'  "total_table_sizes": {sizes},\n'
        + f'  "table_key_counts": {counts}\n'
        + "}"
    )
    print(combined)


def print_total_shared_memory_size_over_time(
    configuration: configuration_module.Configuration,
) -> None:
    memory_over_time = _collect_memory_statistics_over_time(
        Path(configuration.log_directory)
    ).to_json()
    print(memory_over_time)


def print_total_shared_memory_size_over_time_graph(
    configuration: configuration_module.Configuration,
) -> None:
    statistics = _collect_memory_statistics_over_time(
        Path(configuration.log_directory)
    )
    statistics.graph_total_shared_memory_size_over_time()


def print_trace_event(
    configuration: configuration_module.Configuration,
) -> None:
    events = _read_profiling_events(Path(configuration.log_directory))
    print(json.dumps(to_traceevents(events)))


def print_cold_start_phases(
    configuration: configuration_module.Configuration,
) -> None:
    events = _read_profiling_events(Path(configuration.log_directory))
    print(json.dumps(to_cold_start_phases(events), indent=2))


def print_incremental_updates(
    configuration: configuration_module.Configuration,
) -> None:
    events = _read_profiling_events(Path(configuration.log_directory))
    print(json.dumps(to_incremental_updates(events), indent=2))


def print_taint(
    configuration: configuration_module.Configuration,
) -> None:
    events = _read_profiling_events(Path(configuration.log_directory))
    print(json.dumps(to_taint(events), indent=2))


def run_profile(
    configuration: configuration_module.Configuration,
    output: command_arguments.ProfileOutput,
) -> commands.ExitCode:
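    """Dispatch to the printer for the requested profile output format."""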
    if output == command_arguments.ProfileOutput.INDIVIDUAL_TABLE_SIZES:
        print_individual_table_sizes(configuration)
    elif output == command_arguments.ProfileOutput.TOTAL_SHARED_MEMORY_SIZE_OVER_TIME:
        print_total_shared_memory_size_over_time(configuration)
    elif (
        output
        == command_arguments.ProfileOutput.TOTAL_SHARED_MEMORY_SIZE_OVER_TIME_GRAPH
    ):
        print_total_shared_memory_size_over_time_graph(configuration)
    elif output == command_arguments.ProfileOutput.TRACE_EVENT:
        print_trace_event(configuration)
    elif output == command_arguments.ProfileOutput.COLD_START_PHASES:
        print_cold_start_phases(configuration)
    elif output == command_arguments.ProfileOutput.INCREMENTAL_UPDATES:
        print_incremental_updates(configuration)
    elif output == command_arguments.ProfileOutput.TAINT:
        print_taint(configuration)
    else:
        raise RuntimeError(f"Unrecognized output format: {output}")
    return commands.ExitCode.SUCCESS


@remote_logging.log_usage(command_name="profile")
def run(
    configuration: configuration_module.Configuration,
    output: command_arguments.ProfileOutput,
) -> commands.ExitCode:
    try:
        return run_profile(configuration, output)
    except Exception as error:
        raise commands.ClientException(
            f"Exception occurred during profile: {error}"
        ) from error