in train.py [0:0]
def main():
    """Script entry point: set up hyperparameters, data and models, then dispatch.

    Builds the hyperparameter object and logger, prepares the datasets and the
    preprocessing function, and loads the two VAE models. If the
    hyperparameters' ``test_eval`` flag is truthy, only evaluation is run
    (using the second model returned by ``load_vaes``, named ``ema_vae`` —
    presumably an EMA copy of the weights; confirm against ``load_vaes``).
    Otherwise the training loop is run with both models.
    """
    hps, log = set_up_hyperparams()
    # set_up_data may return an updated hyperparameter object, so rebind it.
    hps, train_data, eval_data, preprocess = set_up_data(hps)
    model, ema_model = load_vaes(hps, log)

    if not hps.test_eval:
        train_loop(hps, train_data, eval_data, preprocess, model, ema_model, log)
        return

    # Evaluation-only mode: the training data and the non-EMA model are unused.
    run_test_eval(hps, ema_model, eval_data, preprocess, log)