profiler/distributions.py (149 lines of code) (raw):
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
import argparse
import torch
from torch.autograd import Variable
from profiler.profiling_utils import Profile, profile_print
from pyro.distributions import (Bernoulli, Beta, Categorical, Cauchy, Dirichlet, Exponential, Gamma, LogNormal, Normal,
OneHotCategorical, Poisson, Uniform)
def T(arr):
    """Build a float64 tensor from ``arr`` and wrap it with ``Variable``.

    Used to create every distribution parameter in this module, so all
    profiled distributions run in double precision.
    """
    data = torch.DoubleTensor(arr)
    return Variable(data)
# Name of the profiling backend; handed to the Profile decorators through
# get_tool(). set_tool_cfg() overwrites it from the --tool option.
TOOL = 'timeit'
# Backend configuration handed to the Profile decorators through
# get_tool_cfg(); set_tool_cfg() stores {'repeat': N} for timeit, {} otherwise.
TOOL_CFG = {}

# Registry of profiled distributions: name -> (distribution class, keyword
# arguments for its constructor). Every parameter is a length-4 float64 tensor
# built with T(); main() profiles all entries when --dists is omitted.
DISTRIBUTIONS = {
    'Bernoulli': (Bernoulli, dict(probs=T([0.3] * 4))),
    'Beta': (Beta, dict(concentration1=T([2.4] * 4), concentration0=T([3.2] * 4))),
    'Categorical': (Categorical, dict(probs=T([0.1, 0.3, 0.4, 0.2]))),
    'OneHotCategorical': (OneHotCategorical, dict(probs=T([0.1, 0.3, 0.4, 0.2]))),
    'Dirichlet': (Dirichlet, dict(concentration=T([2.4, 3, 6, 6]))),
    'Normal': (Normal, dict(loc=T([0.5] * 4), scale=T([1.2] * 4))),
    'LogNormal': (LogNormal, dict(loc=T([0.5] * 4), scale=T([1.2] * 4))),
    'Cauchy': (Cauchy, dict(loc=T([0.5] * 4), scale=T([1.2] * 4))),
    'Exponential': (Exponential, dict(rate=T([5.5, 3.2, 4.1, 5.6]))),
    'Poisson': (Poisson, dict(rate=T([5.5, 3.2, 4.1, 5.6]))),
    'Gamma': (Gamma, dict(concentration=T([2.4] * 4), rate=T([3.2] * 4))),
    'Uniform': (Uniform, dict(low=T([0] * 4), high=T([4] * 4))),
}
def get_tool():
    """Return the currently selected profiling tool name (module-level ``TOOL``).

    Handed to ``Profile`` as a callable rather than a value, presumably so the
    decorator sees the value set later by ``set_tool_cfg`` -- confirm in
    ``profiling_utils.Profile``.
    """
    selected = TOOL
    return selected
def get_tool_cfg():
    """Return the current profiling tool configuration dict (module-level ``TOOL_CFG``).

    Like ``get_tool``, this is passed to ``Profile`` as a callable; presumably it
    is evaluated after ``set_tool_cfg`` has run -- confirm in ``Profile``.
    """
    cfg = TOOL_CFG
    return cfg
@Profile(
    tool=get_tool,
    tool_cfg=get_tool_cfg,
    # Torch-backed Pyro distributions have no ``dist_class`` attribute, so fall
    # back to the instance's own class name instead of raising AttributeError.
    fn_id=lambda dist, batch_size, *args, **kwargs:
    'sample_' + getattr(dist, 'dist_class', type(dist)).__name__ + '_N=' + str(batch_size))
def sample(dist, batch_size):
    """Draw ``batch_size`` samples from ``dist`` under the configured profiler.

    Args:
        dist: an instantiated distribution exposing ``sample(sample_shape=...)``.
        batch_size: number of samples; used as the leading sample dimension.

    Returns:
        Whatever the ``Profile`` wrapper returns; ``run_with_tool`` unpacks it as
        a ``(samples, profile_result)`` pair.
    """
    return dist.sample(sample_shape=(batch_size,))
@Profile(
    tool=get_tool,
    tool_cfg=get_tool_cfg,
    # Same fallback as ``sample``: use the class name when ``dist_class`` is absent.
    fn_id=lambda dist, batch, *args, **kwargs:
    'log_prob_' + getattr(dist, 'dist_class', type(dist)).__name__ + '_N=' + str(batch.size()[0]))
def log_prob(dist, batch):
    """Evaluate ``dist.log_prob`` on ``batch`` under the configured profiler.

    Args:
        dist: an instantiated distribution exposing ``log_prob``.
        batch: values to score; its first dimension is reported as ``N`` in the
            profile id.

    Returns:
        Whatever the ``Profile`` wrapper returns; ``run_with_tool`` unpacks it as
        a ``(log_probs, profile_result)`` pair.
    """
    return dist.log_prob(batch)
def run_with_tool(tool, dists, batch_sizes):
    """Profile sampling and log-prob evaluation and print one row per distribution.

    Args:
        tool: ``'timeit'`` (column layout, one 6-decimal cell per measurement) or
            ``'cprofile'`` (row layout); any other value passes ``None`` for every
            layout setting to ``profile_print``.
        dists: names of entries in ``DISTRIBUTIONS`` to profile, in output order.
        batch_sizes: sample sizes; each contributes a SAMPLE and a LOG_PROB column.
    """
    if tool == 'timeit':
        measured_cols = 2 * len(batch_sizes)
        layout = ([14] * (measured_cols + 1), [None] + ['{:.6f}'] * measured_cols, 'column')
    elif tool == 'cprofile':
        layout = ([14, 80], None, 'row')
    else:
        layout = (None, None, None)

    with profile_print(*layout) as out:
        headers = ['DISTRIBUTION']
        for n in batch_sizes:
            headers.extend(('SAMPLE (N=' + str(n) + ')', 'LOG_PROB (N=' + str(n) + ')'))
        out.header(headers)

        for dist_name in dists:
            dist_cls, ctor_kwargs = DISTRIBUTIONS[dist_name]
            row = [dist_name]
            instance = dist_cls(**ctor_kwargs)
            for n in batch_sizes:
                # Score the freshly drawn samples so log_prob sees N values too.
                drawn, sample_timing = sample(instance, batch_size=n)
                _, logprob_timing = log_prob(instance, drawn)
                row.extend((sample_timing, logprob_timing))
            out.push(row)
def set_tool_cfg(args):
    """Store the selected tool and its configuration in module-level globals.

    Args:
        args: parsed options with ``tool`` and ``repeat`` attributes.

    Sets ``TOOL`` to ``args.tool``. For ``'timeit'``, ``TOOL_CFG`` becomes
    ``{'repeat': args.repeat}``, using 5 when ``args.repeat`` is None; for any
    other tool it becomes an empty dict.
    """
    global TOOL, TOOL_CFG
    TOOL = args.tool
    if args.tool != 'timeit':
        TOOL_CFG = {}
        return
    TOOL_CFG = {'repeat': 5 if args.repeat is None else args.repeat}
def _resolve_dist_names(names, parser):
    """Map user-supplied names onto ``DISTRIBUTIONS`` keys, ignoring case and underscores.

    Unknown names are reported through ``parser.error`` (which exits the program)
    instead of surfacing later as a bare ``KeyError``.
    """
    lookup = {key.replace('_', '').lower(): key for key in DISTRIBUTIONS}
    resolved = []
    for name in names:
        key = lookup.get(name.replace('_', '').lower())
        if key is None:
            parser.error('unsupported distribution {!r}; choose from: {}'.format(
                name, ', '.join(sorted(DISTRIBUTIONS))))
        resolved.append(key)
    return resolved


def main():
    """Parse command-line options and profile the requested distributions.

    Options:
        --tool: ``timeit`` (default) or ``cprofile``.
        --batch_sizes: up to 4 sample sizes; defaults to ``[10000, 100000]``.
        --dists: distribution names (matched case-insensitively, underscores
            ignored, e.g. ``one_hot_categorical``); defaults to all of them.
        --repeat: number of repetitions when using ``timeit``; defaults to 5.

    Raises:
        ValueError: if more than 4 batch sizes are given.
    """
    supported = ', '.join(sorted(DISTRIBUTIONS))
    parser = argparse.ArgumentParser(description='Profiling distributions library using various tools.')
    parser.add_argument(
        '--tool',
        nargs='?',
        default='timeit',
        # A bare ``--tool`` would otherwise produce None; use the default tool instead.
        const='timeit',
        choices=['timeit', 'cprofile'],
        help='Profile using tool. One of following should be specified:'
        ' ["timeit", "cprofile"]')
    parser.add_argument(
        '--batch_sizes',
        nargs='*',
        type=int,
        help='Batch size of tensor - max of 4 values allowed. '
        'Default = [10000, 100000]')
    parser.add_argument(
        '--dists',
        nargs='*',
        type=str,
        help='Run tests on distributions. One or more of following distributions '
        'are supported (case-insensitive, underscores ignored): ' + supported + '. '
        'Default - Run profiling on all distributions')
    parser.add_argument(
        '--repeat',
        nargs='?',
        default=5,
        type=int,
        help='When profiling using "timeit", the number of repetitions to '
        'use for the profiled function. default=5. The minimum value '
        'is reported.')
    args = parser.parse_args()
    set_tool_cfg(args)

    batch_sizes = args.batch_sizes or [10000, 100000]
    # The help text allows up to 4 values, so only strictly more than 4 is an error.
    if len(batch_sizes) > 4:
        raise ValueError("Max of 4 batch sizes can be specified.")

    if args.dists:
        dists = _resolve_dist_names(args.dists, parser)
    else:
        dists = sorted(DISTRIBUTIONS.keys())
    run_with_tool(args.tool, dists, batch_sizes)


if __name__ == '__main__':
    main()