examples/mixed_hmm/experiment.py:

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import argparse
import os
import json
import uuid
import functools

import torch

import pyro
import pyro.poutine as poutine
from pyro.infer import TraceEnum_ELBO

from model import model_generic, guide_generic
from seal_data import prepare_seal


def aic_num_parameters(model, guide=None):
    """
    Hacky AIC parameter count that includes all parameters in the model and guide.
    """

    def _size(tensor):
        """Product of the tensor's shape, i.e. its number of elements."""
        s = 1
        for d in tensor.shape:
            s = s * d
        return s

    # Run the loss once under a blocked, param-only trace so that every
    # pyro.param site is initialized and recorded, then sum up their sizes.
    with poutine.block(), poutine.trace(param_only=True) as param_capture:
        TraceEnum_ELBO(max_plate_nesting=2).differentiable_loss(model, guide)

    return sum(_size(node["value"]) for node in param_capture.trace.nodes.values())


def run_expt(args):
    data_dir = args["folder"]
    dataset = "seal"  # args["dataset"]
    seed = args["seed"]
    optim = args["optim"]
    lr = args["learnrate"]
    timesteps = args["timesteps"]
    schedule = [] if not args["schedule"] else [int(i) for i in args["schedule"].split(",")]
    random_effects = {"group": args["group"], "individual": args["individual"]}

    pyro.enable_validation(args["validation"])
    pyro.set_rng_seed(seed)  # reproducible random effect parameter init

    filename = os.path.join(data_dir, "prep_seal_data.csv")
    config = prepare_seal(filename, random_effects)

    model = functools.partial(model_generic, config)  # for JITing
    guide = functools.partial(guide_generic, config)

    # Count the number of parameters once.
    num_parameters = aic_num_parameters(model, guide)

    losses = []

    # SGD
    if optim == "sgd":
        loss_fn = TraceEnum_ELBO(max_plate_nesting=2).differentiable_loss

        # Run the loss once to initialize all parameters, then collect them
        # in unconstrained space for the optimizer.
        with pyro.poutine.trace(param_only=True) as param_capture:
            loss_fn(model, guide)
        params = [site["value"].unconstrained() for site in param_capture.trace.nodes.values()]

        optimizer = torch.optim.Adam(params, lr=lr)

        if schedule:
            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=schedule, gamma=0.5)
            schedule_step_loss = False
        else:
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
            schedule_step_loss = True

        for t in range(timesteps):
            optimizer.zero_grad()
            loss = loss_fn(model, guide)
            loss.backward()
            optimizer.step()
            # ReduceLROnPlateau steps on the loss; MultiStepLR steps on the epoch.
            scheduler.step(loss.item() if schedule_step_loss else t)
            losses.append(loss.item())
            print("Loss: {}, AIC[{}]: ".format(loss.item(), t), 2.0 * loss + 2.0 * num_parameters)

    # LBFGS
    elif optim == "lbfgs":
        loss_fn = TraceEnum_ELBO(max_plate_nesting=2).differentiable_loss

        with pyro.poutine.trace(param_only=True) as param_capture:
            loss_fn(model, guide)
        params = [site["value"].unconstrained() for site in param_capture.trace.nodes.values()]

        optimizer = torch.optim.LBFGS(params, lr=lr)

        if schedule:
            scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=schedule, gamma=0.5)
            schedule_step_loss = False
        else:
            schedule_step_loss = True
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")

        for t in range(timesteps):

            def closure():
                optimizer.zero_grad()
                loss = loss_fn(model, guide)
                loss.backward()
                return loss

            loss = optimizer.step(closure)
            scheduler.step(loss.item() if schedule_step_loss else t)
            losses.append(loss.item())
            print("Loss: {}, AIC[{}]: ".format(loss.item(), t), 2.0 * loss + 2.0 * num_parameters)

    else:
        raise ValueError("unsupported optimizer: {}".format(optim))

    # AIC = 2 * (number of parameters) - 2 * log-likelihood;
    # the final loss plays the role of the negative log-likelihood here.
    aic_final = 2.0 * losses[-1] + 2.0 * num_parameters
    print("AIC final: {}".format(aic_final))

    results = {}
    results["args"] = args
    results["sizes"] = config["sizes"]
    results["likelihoods"] = losses
    results["likelihood_final"] = losses[-1]
    results["aic_final"] = aic_final
    results["aic_num_parameters"] = num_parameters

    if args["resultsdir"] is not None and os.path.exists(args["resultsdir"]):
        # Encode the random effect types in the filename, e.g. "gdin" for
        # discrete group effects and no individual effects.
        re_str = "g" + ("n" if args["group"] is None else "d" if args["group"] == "discrete" else "c")
        re_str += "i" + ("n" if args["individual"] is None else "d" if args["individual"] == "discrete" else "c")
        results_filename = "expt_{}_{}_{}.json".format(dataset, re_str, str(uuid.uuid4().hex)[0:5])
        with open(os.path.join(args["resultsdir"], results_filename), "w") as f:
            json.dump(results, f)

    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-g", "--group", default="none", type=str)
    parser.add_argument("-i", "--individual", default="none", type=str)
    parser.add_argument("-f", "--folder", default="./", type=str)
    parser.add_argument("-o", "--optim", default="sgd", type=str)
    parser.add_argument("-lr", "--learnrate", default=0.05, type=float)
    parser.add_argument("-t", "--timesteps", default=1000, type=int)
    parser.add_argument("-r", "--resultsdir", default="./results", type=str)
    parser.add_argument("-s", "--seed", default=101, type=int)
    parser.add_argument("--schedule", default="", type=str)
    parser.add_argument("--validation", action="store_true")
    args = parser.parse_args()

    if args.group == "none":
        args.group = None
    if args.individual == "none":
        args.individual = None

    run_expt(vars(args))
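# Example invocation (a sketch; the data path is an assumption, since
# prepare_seal expects a prep_seal_data.csv file inside --folder):
#
#   python experiment.py --group discrete --individual discrete --folder ./data \
#       --optim sgd --learnrate 0.05 --timesteps 1000 --schedule 500,750
#
# --group and --individual accept "none" or "discrete"; any other value
# (presumably "continuous") is encoded as "c" in the results filename.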