in run_experiments_real.py [0:0]
def get_experiment_function(args):
    """Build the experiment callable selected by the parsed arguments.

    The ``run_*`` function is chosen from ``args.model`` and ``args.data``.
    It is then partially applied with the dataset-specific and shared keyword
    arguments read from ``args``. The caller supplies any remaining
    arguments the underlying function expects.

    Args:
        args: Parsed arguments namespace. Always reads ``model`` and
            ``data``. Reads ``n_classes`` when ``data == "shapes"``, and
            ``mnist_proportion`` for ``"mnist"`` / ``"single_digit_mnist"``.
            Reads ``no_latent_op`` only when ``model == "autoencoder"``. Also
            reads ``architecture``, ``z_dim`` and ``distribution``, and is
            passed to ``get_n_transformations``.

    Returns:
        A ``functools.partial`` wrapping the selected ``run_*`` function.

    Raises:
        ValueError: If ``args.data`` is not a supported dataset, or if no
            experiment is registered for the ``(model, data)`` combination.
    """
    # Resolve the dataset-specific keyword first, so that an unsupported
    # dataset raises this ValueError. Previously the registry lookup below
    # raised a bare KeyError first, which made this branch unreachable.
    if args.data == "shapes":
        experiment_kwargs = {"n_classes": args.n_classes}
    elif args.data in {"mnist", "single_digit_mnist"}:
        experiment_kwargs = {"proportion": args.mnist_proportion}
    else:
        raise ValueError(f"dataset {args.data} not supported")

    experiments = {
        "run_autoencoder_shapes": run_autoencoder_shapes,
        "run_autoencoder_mnist": run_autoencoder_mnist,
        "run_cci_vae_shapes": run_cci_vae_shapes,
        "run_cci_vae_mnist": run_cci_vae_mnist,
        # NOTE(review): single_digit_mnist reuses run_cci_vae_mnist with the
        # same keyword arguments as "mnist". Confirm how the single-digit
        # subset gets selected. There is no autoencoder entry for this dataset.
        "run_cci_vae_single_digit_mnist": run_cci_vae_mnist,
    }
    experiment_name = f"run_{args.model}_{args.data}"
    try:
        experiment = experiments[experiment_name]
    except KeyError:
        supported = ", ".join(sorted(experiments))
        raise ValueError(
            f"no experiment {experiment_name!r} for model {args.model!r} and "
            f"dataset {args.data!r}; supported: {supported}"
        ) from None
    print(experiment_name)

    # Standard autoencoder only: pass use_latent_op=False when the
    # no_latent_op flag is set.
    if args.model == "autoencoder" and args.no_latent_op:
        experiment_kwargs["use_latent_op"] = False

    n_rotations, n_x_translations, n_y_translations = get_n_transformations(args)
    experiment_kwargs.update(
        n_rotations=n_rotations,
        n_x_translations=n_x_translations,
        n_y_translations=n_y_translations,
        architecture=args.architecture,
        z_dim=args.z_dim,
        distribution=args.distribution,
    )
    # A single partial carries every bound keyword argument. This is
    # equivalent to the previous chain of nested partials, because none of
    # the keys overlap.
    return partial(experiment, **experiment_kwargs)