profiler/hmm.py (60 lines of code) (raw):

# Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 import argparse import os import pickle import re import subprocess import sys from collections import defaultdict from os.path import join, abspath from numpy import median from pyro.util import timed EXAMPLES_DIR = join(abspath(__file__), os.pardir, os.pardir, "examples") def main(args): # Decide what experiments to run. configs = [] for model in args.model.split(","): for seed in args.seed.split(","): config = ["--seed={}".format(seed), "--model={}".format(model), "--num-steps={}".format(args.num_steps)] if args.cuda: config.append("--cuda") if args.jit: config.append("--jit") config.append("--time-compilation") configs.append(tuple(config)) # Run timing experiments serially. results = {} if os.path.exists(args.filename): try: with open(args.filename, "rb") as f: results = pickle.load(f) except Exception: pass for config in configs: with timed() as t: out = subprocess.check_output((sys.executable, "-O", abspath(join(EXAMPLES_DIR, "hmm.py"))) + config, encoding="utf-8") results[config] = t.elapsed if "--jit" in config: matched = re.search(r"time to compile: (\d+\.\d+)", out) if matched: compilation_time = float(matched.group(1)) results[config + ("(compilation time)",)] = compilation_time with open(args.filename, "wb") as f: pickle.dump(results, f) # Group by seed. grouped = defaultdict(list) for config, elapsed in results.items(): grouped[config[1:]].append(elapsed) # Print a table in github markdown format. print("| Min (sec) | Mean (sec) | Max (sec) | python -O examples/hmm.py ... |") print("| -: | -: | -: | - |") for config, times in sorted(grouped.items()): print("| {:0.1f} | {:0.1f} | {:0.1f} | {} |".format( min(times), median(times), max(times), " ".join(config))) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Profiler for examples/hmm.py") parser.add_argument("-f", "--filename", default="hmm_profile.pkl") parser.add_argument("-n", "--num-steps", default=50, type=int) parser.add_argument("-s", "--seed", default="0,1,2,3,4") parser.add_argument("-m", "--model", default="1,2,3,4,5,6,7") parser.add_argument("--cuda", action="store_true") parser.add_argument("--jit", action="store_true") args = parser.parse_args() main(args)