profiler/distributions.py (149 lines of code) (raw):

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
"""Profile sampling and log_prob evaluation for Pyro distributions.

Runs ``sample`` / ``log_prob`` for a fixed set of distributions across one or
more batch sizes, timing each call with either ``timeit`` or ``cProfile`` via
the project's ``Profile`` decorator, and prints a formatted results table.
"""

import argparse

import torch
from torch.autograd import Variable

from profiler.profiling_utils import Profile, profile_print
from pyro.distributions import (Bernoulli, Beta, Categorical, Cauchy,
                                Dirichlet, Exponential, Gamma, LogNormal,
                                Normal, OneHotCategorical, Poisson, Uniform)


def T(arr):
    """Wrap a Python list as a double-precision autograd ``Variable``."""
    return Variable(torch.DoubleTensor(arr))


# Active profiling backend and its configuration; mutated by set_tool_cfg()
# before run_with_tool() is called, and read lazily by the Profile decorator
# through get_tool()/get_tool_cfg().
TOOL = 'timeit'
TOOL_CFG = {}

# Maps a display name to (distribution class, constructor kwargs). All
# parameter tensors have 4 elements, so each sample has event/batch width 4.
DISTRIBUTIONS = {
    'Bernoulli': (Bernoulli, {
        'probs': T([0.3, 0.3, 0.3, 0.3])
    }),
    'Beta': (Beta, {
        'concentration1': T([2.4, 2.4, 2.4, 2.4]),
        'concentration0': T([3.2, 3.2, 3.2, 3.2])
    }),
    'Categorical': (Categorical, {
        'probs': T([0.1, 0.3, 0.4, 0.2])
    }),
    'OneHotCategorical': (OneHotCategorical, {
        'probs': T([0.1, 0.3, 0.4, 0.2])
    }),
    'Dirichlet': (Dirichlet, {
        'concentration': T([2.4, 3, 6, 6])
    }),
    'Normal': (Normal, {
        'loc': T([0.5, 0.5, 0.5, 0.5]),
        'scale': T([1.2, 1.2, 1.2, 1.2])
    }),
    'LogNormal': (LogNormal, {
        'loc': T([0.5, 0.5, 0.5, 0.5]),
        'scale': T([1.2, 1.2, 1.2, 1.2])
    }),
    'Cauchy': (Cauchy, {
        'loc': T([0.5, 0.5, 0.5, 0.5]),
        'scale': T([1.2, 1.2, 1.2, 1.2])
    }),
    'Exponential': (Exponential, {
        'rate': T([5.5, 3.2, 4.1, 5.6])
    }),
    'Poisson': (Poisson, {
        'rate': T([5.5, 3.2, 4.1, 5.6])
    }),
    'Gamma': (Gamma, {
        'concentration': T([2.4, 2.4, 2.4, 2.4]),
        'rate': T([3.2, 3.2, 3.2, 3.2])
    }),
    'Uniform': (Uniform, {
        'low': T([0, 0, 0, 0]),
        'high': T([4, 4, 4, 4])
    })
}


def get_tool():
    """Return the currently selected profiling tool name."""
    return TOOL


def get_tool_cfg():
    """Return the configuration dict for the current profiling tool."""
    return TOOL_CFG


# NOTE(review): both fn_id lambdas read ``dist.dist_class.__name__`` — this
# presumes the distribution instance exposes a ``dist_class`` attribute
# (old Pyro wrapper API); verify against profiling_utils / the Pyro version
# in use.
@Profile(
    tool=get_tool,
    tool_cfg=get_tool_cfg,
    fn_id=lambda dist, batch_size, *args, **kwargs:
        'sample_' + dist.dist_class.__name__ + '_N=' + str(batch_size))
def sample(dist, batch_size):
    """Draw ``batch_size`` samples from ``dist`` (profiled)."""
    return dist.sample(sample_shape=(batch_size,))


@Profile(
    tool=get_tool,
    tool_cfg=get_tool_cfg,
    fn_id=lambda dist, batch, *args, **kwargs:
        'log_prob_' + dist.dist_class.__name__ + '_N=' + str(batch.size()[0]))
def log_prob(dist, batch):
    """Evaluate ``dist.log_prob`` on ``batch`` (profiled)."""
    return dist.log_prob(batch)


def run_with_tool(tool, dists, batch_sizes):
    """Profile sample/log_prob for each named distribution and print a table.

    :param tool: 'timeit' (numeric columns) or 'cprofile' (row layout).
    :param dists: iterable of DISTRIBUTIONS keys to profile.
    :param batch_sizes: batch sizes to profile, one SAMPLE and one LOG_PROB
        column per size.
    """
    column_widths, field_format, template = None, None, None
    if tool == 'timeit':
        # One SAMPLE and one LOG_PROB column per batch size, plus the
        # leading DISTRIBUTION name column.
        profile_cols = 2 * len(batch_sizes)
        column_widths = [14] * (profile_cols + 1)
        field_format = [None] + ['{:.6f}'] * profile_cols
        template = 'column'
    elif tool == 'cprofile':
        column_widths = [14, 80]
        template = 'row'
    with profile_print(column_widths, field_format, template) as out:
        column_headers = []
        for size in batch_sizes:
            column_headers += ['SAMPLE (N=' + str(size) + ')',
                               'LOG_PROB (N=' + str(size) + ')']
        out.header(['DISTRIBUTION'] + column_headers)
        for dist_name in dists:
            Dist, params = DISTRIBUTIONS[dist_name]
            result_row = [dist_name]
            dist = Dist(**params)
            for size in batch_sizes:
                # The Profile decorator returns (result, profiling stats);
                # the drawn batch is re-used as the log_prob input.
                sample_result, sample_prof = sample(dist, batch_size=size)
                _, logpdf_prof = log_prob(dist, sample_result)
                result_row += [sample_prof, logpdf_prof]
            out.push(result_row)


def set_tool_cfg(args):
    """Set the module-level TOOL / TOOL_CFG from parsed CLI arguments."""
    global TOOL, TOOL_CFG
    TOOL = args.tool
    tool_cfg = {}
    if args.tool == 'timeit':
        # Defensive fallback: argparse already defaults --repeat to 5, but
        # guard against a caller-constructed namespace with repeat=None.
        repeat = 5
        if args.repeat is not None:
            repeat = args.repeat
        tool_cfg = {'repeat': repeat}
    TOOL_CFG = tool_cfg


def main():
    # BUGFIX: description previously concatenated to "varioustools." —
    # adjacent string literals need an explicit separating space.
    parser = argparse.ArgumentParser(description='Profiling distributions library using various '
                                                 'tools.')
    parser.add_argument(
        '--tool',
        nargs='?',
        default='timeit',
        help='Profile using tool. One of following should be specified:'
             ' ["timeit", "cprofile"]')
    parser.add_argument(
        '--batch_sizes',
        nargs='*',
        type=int,
        help='Batch size of tensor - max of 4 values allowed. '
             'Default = [10000, 100000]')
    parser.add_argument(
        '--dists',
        nargs='*',
        type=str,
        # BUGFIX: closing quote after "bernoulli" was missing in the help
        # text. NOTE(review): "halfcauchy" is listed but has no entry in
        # DISTRIBUTIONS — confirm whether it should be supported.
        help='Run tests on distributions. One or more of following distributions '
             'are supported: ["bernoulli", "beta", "categorical", "dirichlet", '
             '"normal", "lognormal", "halfcauchy", "cauchy", "exponential", '
             '"poisson", "one_hot_categorical", "gamma", "uniform"] '
             'Default - Run profiling on all distributions')
    parser.add_argument(
        '--repeat',
        nargs='?',
        default=5,
        type=int,
        help='When profiling using "timeit", the number of repetitions to '
             'use for the profiled function. default=5. The minimum value '
             'is reported.')
    args = parser.parse_args()
    set_tool_cfg(args)
    dists = args.dists
    batch_sizes = args.batch_sizes
    if not args.batch_sizes:
        batch_sizes = [10000, 100000]
    # BUGFIX: was ``>= 4``, which rejected exactly 4 batch sizes even though
    # the error message and the --batch_sizes help both promise a max of 4.
    if len(batch_sizes) > 4:
        raise ValueError("Max of 4 batch sizes can be specified.")
    if not dists:
        dists = sorted(DISTRIBUTIONS.keys())
    run_with_tool(args.tool, dists, batch_sizes)


if __name__ == '__main__':
    main()