profiler/profiling_utils.py (95 lines of code) (raw):

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

"""Helpers for timing / cProfile-ing functions and printing results as tables.

cProfile dumps produced through :class:`Profile` are written under ``PROF_DIR``.
"""

import cProfile
from io import StringIO
import functools
import os
import pstats
import timeit
from contextlib import contextmanager

from prettytable import ALL, PrettyTable

FILE = os.path.abspath(__file__)
# Directory next to this module where cProfile stat dumps are written.
PROF_DIR = os.path.join(os.path.dirname(FILE), 'data')
# exist_ok=True avoids the exists()/makedirs() race when several processes
# import this module concurrently.
os.makedirs(PROF_DIR, exist_ok=True)


class ProfilePrinter:
    """Accumulate values into a PrettyTable and print it.

    Two layouts are supported:

    * ``'column'``: the values passed to :meth:`header` become the first table
      row, and each :meth:`push` appends one row of values.
    * ``'row'``: the table has two columns (``KEY`` / ``VALUE``), and each
      :meth:`push` appends one row per header key, paired with its value.

    :param column_widths: optional sequence of max column widths, indexed like
        the table's field names (the header values for ``'column'``, or
        ``['KEY', 'VALUE']`` for ``'row'``).
    :param field_format: optional sequence of ``str.format`` templates, one per
        pushed value; a falsy entry means ``str(value)`` is used instead.
    :param template: ``'column'`` or ``'row'``.
    """

    def __init__(self, column_widths=None, field_format=None, template='column'):
        assert template in ('column', 'row')
        self._template = template
        self._column_widths = column_widths
        self._field_format = field_format
        self._header = None
        # Both layouts share the same table configuration; they differ only in
        # how header() and push() fill the table.
        self.table = PrettyTable(header=False, hrules=ALL)

    def _formatted_values(self, values):
        """Apply ``field_format`` to ``values``; return ``values`` as-is if unset."""
        if self._field_format is None:
            return values
        assert len(self._field_format) == len(values)
        return [fmt.format(val) if fmt else str(val)
                for fmt, val in zip(self._field_format, values)]

    def _add_using_row_format(self, values):
        """Add one ``[key, value]`` row for each header key."""
        assert len(self._header) == len(values)
        formatted_vals = self._formatted_values(values)
        for key, val in zip(self._header, formatted_vals):
            self.table.add_row([key, val])

    def _add_using_column_format(self, values):
        """Add ``values`` as a single table row."""
        formatted_vals = self._formatted_values(values)
        self.table.add_row(formatted_vals)

    def push(self, values):
        """Add ``values`` to the table using the configured layout."""
        if self._template == 'column':
            self._add_using_column_format(values)
        else:
            self._add_using_row_format(values)

    def header(self, values):
        """Set the header keys and configure left alignment / max widths.

        For the ``'column'`` layout the header values are also added as the
        first row, because the table is created with ``header=False``.
        """
        self._header = values
        if self._template == 'column':
            field_names = values
            self.table.add_row(values)
        else:
            field_names = ['KEY', 'VALUE']
        self.table.field_names = field_names
        for i, name in enumerate(field_names):
            self.table.align[name] = 'l'
            if self._column_widths:
                self.table.max_width[name] = self._column_widths[i]

    def print(self):
        """Print the accumulated table to stdout."""
        print(self.table)


@contextmanager
def profile_print(column_widths=None, field_format=None, template='column'):
    """Yield a :class:`ProfilePrinter` and print its table on exit.

    The table is printed from a ``finally`` clause, so it is emitted even if
    the ``with`` body raises.
    """
    out_buffer = ProfilePrinter(column_widths, field_format, template)
    try:
        yield out_buffer
    finally:
        out_buffer.print()


def profile_timeit(fn_callable, repeat=1):
    """Call ``fn_callable`` and time it with :func:`timeit.repeat`.

    ``fn_callable`` is invoked once to capture its return value, then
    ``repeat`` more times (one call per timing run).

    :returns: ``(return_value, best_time)``, where ``best_time`` is the minimum
        of the timings (in seconds) reported by ``timeit.repeat``.
    """
    ret = fn_callable()
    return ret, min(timeit.repeat(fn_callable, repeat=repeat, number=1))


def profile_cprofile(fn_callable, prof_file):
    """Run ``fn_callable`` under cProfile and dump the stats to ``prof_file``.

    :returns: ``(return_value, report)``, where ``report`` is the ``pstats``
        text output with directory names stripped, sorted by cumulative time
        and restricted to the first 50% of entries.
    """
    prof = cProfile.Profile()
    ret = prof.runcall(fn_callable)
    prof.dump_stats(prof_file)
    prof_stats = StringIO()
    p = pstats.Stats(prof_file, stream=prof_stats)
    p.strip_dirs().sort_stats('cumulative').print_stats(0.5)
    return ret, prof_stats.getvalue()


class Profile:
    """Decorator that profiles every call of the wrapped function.

    :param tool: ``'timeit'`` or ``'cprofile'``, or a zero-argument callable
        returning one of them (resolved when the wrapped function is called).
    :param tool_cfg: dict of keyword arguments forwarded to
        :func:`profile_timeit` (unused for ``'cprofile'``), or a zero-argument
        callable returning such a dict. ``None`` means no extra options.
    :param fn_id: callable that receives the wrapped function's arguments and
        returns the dump file name, joined onto ``PROF_DIR``; only used with
        ``'cprofile'``.

    The decorated function returns the ``(return_value, measurement)`` tuple
    produced by :func:`profile_timeit` or :func:`profile_cprofile`.

    :raises ValueError: when called with an unknown ``tool``.
    """

    def __init__(self, tool, tool_cfg, fn_id):
        self.tool = tool
        self.tool_cfg = tool_cfg
        self.fn_id = fn_id

    def _set_decorator_params(self):
        """Replace lazily-supplied ``tool`` / ``tool_cfg`` callables with their results."""
        if callable(self.tool):
            self.tool = self.tool()
        if callable(self.tool_cfg):
            self.tool_cfg = self.tool_cfg()

    def __call__(self, fn):
        # functools.wraps keeps fn's __name__/__doc__ on the wrapper so that
        # decorated functions remain identifiable (e.g. in reports and tracebacks).
        @functools.wraps(fn)
        def wrapped_fn(*args, **kwargs):
            self._set_decorator_params()
            fn_callable = functools.partial(fn, *args, **kwargs)
            if self.tool == 'timeit':
                # Treat a missing config as "no options" instead of failing on **None.
                return profile_timeit(fn_callable, **(self.tool_cfg or {}))
            elif self.tool == 'cprofile':
                prof_file = os.path.join(PROF_DIR, self.fn_id(*args, **kwargs))
                return profile_cprofile(fn_callable, prof_file=prof_file)
            else:
                raise ValueError('Invalid profiling tool specified: {}.'.format(self.tool))
        return wrapped_fn