profiler/profiling_utils.py (95 lines of code) (raw):
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import cProfile
from io import StringIO
import functools
import os
import pstats
import timeit
from contextlib import contextmanager
from prettytable import ALL, PrettyTable
# Absolute path of this module; the profiling output directory lives next to it.
FILE = os.path.abspath(__file__)
# Directory that receives the profiler dump files written via ``Profile``.
PROF_DIR = os.path.join(os.path.dirname(FILE), 'data')
# ``exist_ok=True`` avoids the check-then-create race of
# ``if not os.path.exists(...): os.makedirs(...)`` when several processes
# import this module at the same time.
os.makedirs(PROF_DIR, exist_ok=True)
class ProfilePrinter:
    """Collect rows of profiling results in a ``PrettyTable`` and print them.

    Two layouts are available through ``template``:

    * ``'column'`` -- the header values are stored as a table row and each
      ``push`` appends one row of values.
    * ``'row'`` -- each ``push`` appends one ``[header_entry, value]`` row per
      header entry, under the field names ``KEY`` and ``VALUE``.
    """

    def __init__(self, column_widths=None, field_format=None, template='column'):
        """Create an empty table.

        :param column_widths: optional sequence indexed by column position,
            applied as ``max_width`` when ``header`` is called.
        :param field_format: optional sequence of format strings (or falsy
            entries meaning ``str``), one per pushed value.
        :param template: either ``'column'`` or ``'row'``.
        """
        assert template in ('column', 'row')
        self._header = None
        self._field_format = field_format
        self._column_widths = column_widths
        self._template = template
        # Both templates use the same table configuration.
        self.table = PrettyTable(header=False, hrules=ALL)

    def _formatted_values(self, values):
        """Return ``values`` rendered through ``field_format``.

        When no ``field_format`` was given, ``values`` is returned unchanged.
        """
        formats = self._field_format
        if formats is None:
            return values
        assert len(formats) == len(values)
        rendered = []
        for fmt, value in zip(formats, values):
            # A falsy format entry falls back to plain ``str``.
            rendered.append(fmt.format(value) if fmt else str(value))
        return rendered

    def _add_using_row_format(self, values):
        """Append one ``[header_entry, value]`` row per header entry."""
        assert len(self._header) == len(values)
        cells = self._formatted_values(values)
        for position, key in enumerate(self._header):
            self.table.add_row([key, cells[position]])

    def _add_using_column_format(self, values):
        """Append ``values`` as a single formatted table row."""
        self.table.add_row(self._formatted_values(values))

    def push(self, values):
        """Add one record of ``values`` using the configured template."""
        if self._template != 'column':
            self._add_using_row_format(values)
            return
        self._add_using_column_format(values)

    def header(self, values):
        """Record the header and configure field names, alignment and widths.

        In the ``'column'`` template the header values are also appended as a
        regular table row (the table itself is created with ``header=False``).
        """
        self._header = values
        if self._template != 'column':
            field_names = ['KEY', 'VALUE']
        else:
            field_names = values
            self.table.add_row(values)
        self.table.field_names = field_names
        for position, name in enumerate(field_names):
            self.table.align[name] = 'l'
            if self._column_widths:
                self.table.max_width[name] = self._column_widths[position]

    def print(self):
        """Print the table to standard output."""
        print(self.table)
@contextmanager
def profile_print(column_widths=None, field_format=None, template='column'):
    """Context manager yielding a ``ProfilePrinter`` that is printed on exit.

    The arguments are forwarded to ``ProfilePrinter``. The table is printed
    in a ``finally`` clause, so it is emitted even if the managed block
    raises.
    """
    printer = ProfilePrinter(
        column_widths=column_widths,
        field_format=field_format,
        template=template,
    )
    try:
        yield printer
    finally:
        printer.print()
def profile_timeit(fn_callable, repeat=1):
    """Time ``fn_callable`` with ``timeit`` and return its result.

    The callable is invoked once up front to capture its return value, then
    ``repeat`` more times (one execution per repetition) for timing.

    :param fn_callable: zero-argument callable to profile.
    :param repeat: number of timing repetitions.
    :return: tuple ``(return_value, best_time)`` where ``best_time`` is the
        minimum over the repetitions.
    """
    result = fn_callable()
    timings = timeit.repeat(fn_callable, repeat=repeat, number=1)
    best = min(timings)
    return result, best
def profile_cprofile(fn_callable, prof_file):
    """Run ``fn_callable`` under ``cProfile`` and return a text report.

    The raw statistics are dumped to ``prof_file`` and then loaded back from
    that file to produce the report.

    :param fn_callable: zero-argument callable to profile.
    :param prof_file: path the profiler statistics are written to.
    :return: tuple ``(return_value, report)`` where ``report`` is the
        ``print_stats`` output sorted by cumulative time and restricted
        with the argument ``0.5``.
    """
    profiler = cProfile.Profile()
    result = profiler.runcall(fn_callable)
    profiler.dump_stats(prof_file)

    report = StringIO()
    stats = pstats.Stats(prof_file, stream=report)
    stats.strip_dirs()
    stats.sort_stats('cumulative')
    stats.print_stats(0.5)
    return result, report.getvalue()
class Profile:
    """Decorator that runs the wrapped function under a profiling tool.

    :param tool: ``'timeit'`` or ``'cprofile'``, or a zero-argument callable
        returning one of those names.
    :param tool_cfg: keyword arguments forwarded to ``profile_timeit``, or a
        zero-argument callable returning them. ``None`` is treated as an
        empty configuration.
    :param fn_id: callable invoked with the wrapped function's arguments; its
        result is used as the file name (inside ``PROF_DIR``) for the
        ``cprofile`` dump.
    """

    def __init__(self, tool, tool_cfg, fn_id):
        self.tool = tool
        self.tool_cfg = tool_cfg
        self.fn_id = fn_id

    def _set_decorator_params(self):
        """Resolve callable ``tool`` / ``tool_cfg`` into concrete values.

        The resolved values replace the callables on the instance, so each
        callable is evaluated on the first call that finds it.
        """
        if callable(self.tool):
            self.tool = self.tool()
        if callable(self.tool_cfg):
            self.tool_cfg = self.tool_cfg()

    def __call__(self, fn):
        """Return a wrapper that profiles ``fn`` with the configured tool.

        :raises ValueError: when the wrapper is called and the resolved tool
            is neither ``'timeit'`` nor ``'cprofile'``.
        """
        # Preserve ``fn``'s name, docstring and ``__wrapped__`` on the wrapper.
        @functools.wraps(fn)
        def wrapped_fn(*args, **kwargs):
            # Resolution is deferred to call time rather than decoration time.
            self._set_decorator_params()
            fn_callable = functools.partial(fn, *args, **kwargs)
            if self.tool == 'timeit':
                # ``None`` config would otherwise fail on ``**`` unpacking.
                return profile_timeit(fn_callable, **(self.tool_cfg or {}))
            if self.tool == 'cprofile':
                prof_file = os.path.join(PROF_DIR, self.fn_id(*args, **kwargs))
                return profile_cprofile(fn_callable, prof_file=prof_file)
            raise ValueError('Invalid profiling tool specified: {}.'.format(self.tool))

        return wrapped_fn