in examples/hmm.py [0:0]
def main(args):
    """Train one of the Pyro HMM model variants on the JSB chorales dataset.

    Training is MAP Baum-Welch: an AutoDelta guide learns point estimates of
    the conditional probability tables (sites named ``probs_*``) while the
    discrete hidden states are marginalized out by enumeration (or sampled
    via TMC when ``args.tmc`` is set).  After training, reports the
    per-observation train loss (prior excluded) and the test-set loss.

    :param args: parsed command-line namespace.  Fields read here: ``cuda``,
        ``model``, ``truncate``, ``seed``, ``print_shapes``, ``batch_size``,
        ``learning_rate``, ``tmc``, ``tmc_num_samples``, ``jit``,
        ``time_compilation``, ``num_steps``.
    :raises NotImplementedError: if both ``args.tmc`` and ``args.jit`` are set.
    """
    if args.cuda:
        # Make all newly created tensors live on the GPU.
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    logging.info('Loading data')
    data = poly.load_data(poly.JSB_CHORALES)

    logging.info('-' * 40)
    # `models` is a project-level registry mapping args.model to a model fn.
    model = models[args.model]
    logging.info('Training {} on {} sequences'.format(
        model.__name__, len(data['train']['sequences'])))
    sequences = data['train']['sequences']
    lengths = data['train']['sequence_lengths']

    # find all the notes that are present at least once in the training set
    # (assumes sequences is batch x time x notes with 0/1 entries — the
    # double .sum(0) reduces batch and time dims; TODO confirm against poly)
    present_notes = ((sequences == 1).sum(0).sum(0) > 0)
    # remove notes that are never played (we remove 37/88 notes)
    sequences = sequences[..., present_notes]

    if args.truncate:
        # Cap every sequence at args.truncate time steps.
        lengths = lengths.clamp(max=args.truncate)
        sequences = sequences[:, :args.truncate]
    # Total observed time steps; used to normalize reported losses below.
    num_observations = float(lengths.sum())

    # Reset global Pyro state so repeated runs are reproducible and isolated.
    pyro.set_rng_seed(args.seed)
    pyro.clear_param_store()
    pyro.enable_validation(__debug__)

    # We'll train using MAP Baum-Welch, i.e. MAP estimation while marginalizing
    # out the hidden state x. This is accomplished via an automatic guide that
    # learns point estimates of all of our conditional probability tables,
    # named probs_*.
    guide = AutoDelta(poutine.block(model, expose_fn=lambda msg: msg["name"].startswith("probs_")))

    # To help debug our tensor shapes, let's print the shape of each site's
    # distribution, value, and log_prob tensor. Note this information is
    # automatically printed on most errors inside SVI.
    if args.print_shapes:
        # model_0 requires one fewer enumeration dim than the other models,
        # mirroring the max_plate_nesting choice below.
        first_available_dim = -2 if model is model_0 else -3
        guide_trace = poutine.trace(guide).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        # Replay the enumerated model against the guide's sampled trace so
        # the recorded shapes reflect actual training-time broadcasting.
        model_trace = poutine.trace(
            poutine.replay(poutine.enum(model, first_available_dim), guide_trace)).get_trace(
            sequences, lengths, args=args, batch_size=args.batch_size)
        logging.info(model_trace.format_shapes())

    # Enumeration requires a TraceEnum elbo and declaring the max_plate_nesting.
    # All of our models have two plates: "data" and "tones".
    optim = Adam({'lr': args.learning_rate})
    if args.tmc:
        if args.jit:
            raise NotImplementedError("jit support not yet added for TraceTMC_ELBO")
        elbo = TraceTMC_ELBO(max_plate_nesting=1 if model is model_0 else 2)
        # Reconfigure every parallel-enumerated site to draw
        # args.tmc_num_samples samples (unexpanded) instead of exhaustively
        # enumerating its support.
        tmc_model = poutine.infer_config(
            model,
            lambda msg: {"num_samples": args.tmc_num_samples, "expand": False} if msg["infer"].get("enumerate", None) == "parallel" else {})  # noqa: E501
        svi = SVI(tmc_model, guide, optim, elbo)
    else:
        Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
        elbo = Elbo(max_plate_nesting=1 if model is model_0 else 2,
                    # model_7 presumably has no enumerable sites — verify.
                    strict_enumeration_warning=(model is not model_7),
                    jit_options={"time_compilation": args.time_compilation})
        svi = SVI(model, guide, optim, elbo)

    # We'll train on small minibatches.
    logging.info('Step\tLoss')
    for step in range(args.num_steps):
        loss = svi.step(sequences, lengths, args=args, batch_size=args.batch_size)
        logging.info('{: >5d}\t{}'.format(step, loss / num_observations))

    if args.jit and args.time_compilation:
        # NOTE(review): reaches into a private attribute of the jitted elbo.
        logging.debug('time to compile: {} s.'.format(elbo._differentiable_loss.compile_time))

    # We evaluate on the entire training dataset,
    # excluding the prior term so our results are comparable across models.
    train_loss = elbo.loss(model, guide, sequences, lengths, args, include_prior=False)
    logging.info('training loss = {}'.format(train_loss / num_observations))

    # Finally we evaluate on the test dataset.
    logging.info('-' * 40)
    logging.info('Evaluating on {} test sequences'.format(len(data['test']['sequences'])))
    # Restrict test data to the same note subset selected from the train set.
    sequences = data['test']['sequences'][..., present_notes]
    lengths = data['test']['sequence_lengths']
    if args.truncate:
        lengths = lengths.clamp(max=args.truncate)
    num_observations = float(lengths.sum())

    # note that since we removed unseen notes above (to make the problem a bit easier and for
    # numerical stability) this test loss may not be directly comparable to numbers
    # reported on this dataset elsewhere.
    test_loss = elbo.loss(model, guide, sequences, lengths, args=args, include_prior=False)
    logging.info('test loss = {}'.format(test_loss / num_observations))

    # We expect models with higher capacity to perform better,
    # but eventually overfit to the training set.
    capacity = sum(value.reshape(-1).size(0)
                   for value in pyro.get_param_store().values())
    logging.info('{} capacity = {} parameters'.format(model.__name__, capacity))