in examples/dmm/dmm.py [0:0]
def main(args):
    # setup logging
    log = get_logger(args.log)
    log(args)

    data = poly.load_data(poly.JSB_CHORALES)
    training_seq_lengths = data['train']['sequence_lengths']
    training_data_sequences = data['train']['sequences']
    test_seq_lengths = data['test']['sequence_lengths']
    test_data_sequences = data['test']['sequences']
    val_seq_lengths = data['valid']['sequence_lengths']
    val_data_sequences = data['valid']['sequences']
    N_train_data = len(training_seq_lengths)
    N_train_time_slices = float(torch.sum(training_seq_lengths))
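
    # N_mini_batches is a ceiling division: we get one extra (partial) mini-batch whenever
    # N_train_data is not an exact multiple of args.mini_batch_size. For example, with
    # hypothetical values of 229 training sequences and a mini-batch size of 20, this
    # yields 11 full mini-batches plus 1 partial one, i.e. N_mini_batches = 12.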
    N_mini_batches = int(N_train_data / args.mini_batch_size +
                         int(N_train_data % args.mini_batch_size > 0))
    log("N_train_data: %d avg. training seq. length: %.2f N_mini_batches: %d" %
        (N_train_data, training_seq_lengths.float().mean(), N_mini_batches))

    # how often we do validation/test evaluation during training
    val_test_frequency = 50
    # the number of samples we use to do the evaluation
    n_eval_samples = 1

    # package repeated copies of val/test data for faster evaluation
    # (i.e. set us up for vectorization)
    def rep(x):
        rep_shape = torch.Size([x.size(0) * n_eval_samples]) + x.size()[1:]
        repeat_dims = [1] * len(x.size())
        repeat_dims[0] = n_eval_samples
        return x.repeat(repeat_dims).reshape(n_eval_samples, -1).transpose(1, 0).reshape(rep_shape)
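
    # Shape-level sketch of rep (illustrative values only, not used below): with
    # n_eval_samples = 2, a 1-d tensor of sequence lengths
    #   rep(torch.tensor([10, 12, 8]))  ->  tensor([10, 10, 12, 12, 8, 8])
    # and, more generally, a tensor of shape (B, ...) comes back with shape
    # (n_eval_samples * B, ...). With the default n_eval_samples = 1 used here,
    # rep is effectively the identity.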

    # get the validation/test data ready for the dmm: pack into sequences, etc.
    val_seq_lengths = rep(val_seq_lengths)
    test_seq_lengths = rep(test_seq_lengths)
    val_batch, val_batch_reversed, val_batch_mask, val_seq_lengths = poly.get_mini_batch(
        torch.arange(n_eval_samples * val_data_sequences.shape[0]), rep(val_data_sequences),
        val_seq_lengths, cuda=args.cuda)
    test_batch, test_batch_reversed, test_batch_mask, test_seq_lengths = poly.get_mini_batch(
        torch.arange(n_eval_samples * test_data_sequences.shape[0]), rep(test_data_sequences),
        test_seq_lengths, cuda=args.cuda)
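
    # Note: the full validation and test sets are packed into single padded batches here
    # (together with their reversed copies and masks) once, up front, so that do_evaluation
    # below can call svi.evaluate_loss repeatedly without re-packing the data every time.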

    # instantiate the dmm
    dmm = DMM(rnn_dropout_rate=args.rnn_dropout_rate, num_iafs=args.num_iafs,
              iaf_dim=args.iaf_dim, use_cuda=args.cuda)
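
    # (Descriptive note, inferred from the argument names and the DMM class defined elsewhere
    # in this file: rnn_dropout_rate presumably controls dropout inside the guide RNN, while
    # num_iafs / iaf_dim presumably configure optional inverse autoregressive flows that
    # enrich the variational distribution over the latents.)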

    # setup optimizer
    adam_params = {"lr": args.learning_rate, "betas": (args.beta1, args.beta2),
                   "clip_norm": args.clip_norm, "lrd": args.lr_decay,
                   "weight_decay": args.weight_decay}
    adam = ClippedAdam(adam_params)
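
    # ClippedAdam is the Adam variant from pyro.optim that additionally clips gradient norms
    # to clip_norm and applies a multiplicative per-step learning rate decay given by lrd;
    # the betas / weight_decay entries are passed through to the underlying Adam update.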

    # setup inference algorithm
    if args.tmc:
        if args.jit:
            raise NotImplementedError("no JIT support yet for TMC")
        tmc_loss = TraceTMC_ELBO()
        dmm_guide = config_enumerate(dmm.guide, default="parallel",
                                     num_samples=args.tmc_num_samples, expand=False)
        svi = SVI(dmm.model, dmm_guide, adam, loss=tmc_loss)
    elif args.tmcelbo:
        if args.jit:
            raise NotImplementedError("no JIT support yet for TMC ELBO")
        elbo = TraceEnum_ELBO()
        dmm_guide = config_enumerate(dmm.guide, default="parallel",
                                     num_samples=args.tmc_num_samples, expand=False)
        svi = SVI(dmm.model, dmm_guide, adam, loss=elbo)
    else:
        elbo = JitTrace_ELBO() if args.jit else Trace_ELBO()
        svi = SVI(dmm.model, dmm.guide, adam, loss=elbo)
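
    # To summarize the three configurations above: TraceTMC_ELBO uses a tensor Monte Carlo
    # objective with a multi-sample, parallel-enumerated guide; TraceEnum_ELBO uses the same
    # multi-sample parallel sampling of the guide but with the standard ELBO; otherwise we
    # fall back to the plain Trace_ELBO, optionally JIT-compiled via JitTrace_ELBO when
    # args.jit is set.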

    # now we're going to define some functions we need to form the main training loop

    # saves the model and optimizer states to disk
    def save_checkpoint():
        log("saving model to %s..." % args.save_model)
        torch.save(dmm.state_dict(), args.save_model)
        log("saving optimizer states to %s..." % args.save_opt)
        adam.save(args.save_opt)
        log("done saving model and optimizer checkpoints to disk.")

    # loads the model and optimizer states from disk
    def load_checkpoint():
        assert exists(args.load_opt) and exists(args.load_model), \
            "--load-model and/or --load-opt misspecified"
        log("loading model from %s..." % args.load_model)
        dmm.load_state_dict(torch.load(args.load_model))
        log("loading optimizer states from %s..." % args.load_opt)
        adam.load(args.load_opt)
        log("done loading model and optimizer states.")

    # prepare a mini-batch and take a gradient step to minimize -elbo
    def process_minibatch(epoch, which_mini_batch, shuffled_indices):
        if args.annealing_epochs > 0 and epoch < args.annealing_epochs:
            # compute the KL annealing factor appropriate for the current mini-batch in the current epoch
            min_af = args.minimum_annealing_factor
            annealing_factor = min_af + (1.0 - min_af) * \
                (float(which_mini_batch + epoch * N_mini_batches + 1) /
                 float(args.annealing_epochs * N_mini_batches))
        else:
            # by default the KL annealing factor is unity
            annealing_factor = 1.0
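
        # Worked example (hypothetical numbers): with minimum_annealing_factor = 0.2,
        # annealing_epochs = 1000 and N_mini_batches = 10, the factor starts at
        # 0.2 + 0.8 * 1/10000 ~= 0.2001 on the very first mini-batch, increases linearly
        # to exactly 1.0 on the last mini-batch of epoch 999, and from epoch 1000 onwards
        # the else branch above pins it at 1.0.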

        # compute which sequences in the training set we should grab
        mini_batch_start = (which_mini_batch * args.mini_batch_size)
        mini_batch_end = np.min([(which_mini_batch + 1) * args.mini_batch_size, N_train_data])
        mini_batch_indices = shuffled_indices[mini_batch_start:mini_batch_end]
        # grab a fully prepped mini-batch using the helper function in the data loader
        mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths \
            = poly.get_mini_batch(mini_batch_indices, training_data_sequences,
                                  training_seq_lengths, cuda=args.cuda)
        # do an actual gradient step
        loss = svi.step(mini_batch, mini_batch_reversed, mini_batch_mask,
                        mini_batch_seq_lengths, annealing_factor)
        # keep track of the training loss
        return loss

    # helper function for doing evaluation
    def do_evaluation():
        # put the RNN into evaluation mode (i.e. turn off drop-out if applicable)
        dmm.rnn.eval()
        # compute the validation and test loss n_samples many times
        val_nll = svi.evaluate_loss(val_batch, val_batch_reversed, val_batch_mask,
                                    val_seq_lengths) / float(torch.sum(val_seq_lengths))
        test_nll = svi.evaluate_loss(test_batch, test_batch_reversed, test_batch_mask,
                                     test_seq_lengths) / float(torch.sum(test_seq_lengths))
        # put the RNN back into training mode (i.e. turn on drop-out if applicable)
        dmm.rnn.train()
        return val_nll, test_nll
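
    # Note that do_evaluation reports an estimate of -elbo from svi.evaluate_loss (no
    # gradient step is taken), normalized by the total number of time slices in the
    # (repeated) val/test sets, so the numbers are comparable to the per-time-slice
    # training loss logged in the training loop below.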

    # if checkpoint files provided, load model and optimizer states from disk before we start training
    if args.load_opt != '' and args.load_model != '':
        load_checkpoint()

    #################
    # TRAINING LOOP #
    #################
    times = [time.time()]
    for epoch in range(args.num_epochs):
        # if specified, save model and optimizer states to disk every checkpoint_freq epochs
        if args.checkpoint_freq > 0 and epoch > 0 and epoch % args.checkpoint_freq == 0:
            save_checkpoint()

        # accumulator for our estimate of the negative log likelihood (or rather -elbo) for this epoch
        epoch_nll = 0.0
        # prepare mini-batch subsampling indices for this epoch
        shuffled_indices = torch.randperm(N_train_data)

        # process each mini-batch; this is where we take gradient steps
        for which_mini_batch in range(N_mini_batches):
            epoch_nll += process_minibatch(epoch, which_mini_batch, shuffled_indices)

        # report training diagnostics
        times.append(time.time())
        epoch_time = times[-1] - times[-2]
        log("[training epoch %04d] %.4f \t\t\t\t(dt = %.3f sec)" %
            (epoch, epoch_nll / N_train_time_slices, epoch_time))

        # do evaluation on test and validation data and report results
        if val_test_frequency > 0 and epoch > 0 and epoch % val_test_frequency == 0:
            val_nll, test_nll = do_evaluation()
            log("[val/test epoch %04d] %.4f %.4f" % (epoch, val_nll, test_nll))