def get_experiment_function()

in run_experiments_real.py [0:0]


class _UnsupportedExperimentError(ValueError, KeyError):
    """Raised when ``get_experiment_function`` cannot build the requested experiment.

    Subclasses ``ValueError``, the error type used for an unsupported dataset.
    It also subclasses ``KeyError``, which the bare dictionary lookup used to
    raise for an unknown model/dataset combination. Existing ``except`` clauses
    for either type therefore keep working.
    """

    def __str__(self):
        # KeyError.__str__ would repr() the message (wrapping it in quotes);
        # use the plain BaseException formatting instead.
        return BaseException.__str__(self)


def get_experiment_function(args):
    """Build the experiment callable selected by the command-line arguments.

    The experiment is looked up under ``f"run_{args.model}_{args.data}"``
    (that name is printed). It is then partially applied with the
    dataset-specific keyword argument, the optional ``use_latent_op=False``
    override, and the shared transformation / model keyword arguments.

    Args:
        args: Parsed-arguments namespace. Always reads ``model``, ``data``,
            ``architecture``, ``z_dim`` and ``distribution``. Also reads
            ``n_classes`` when ``data == "shapes"``, ``mnist_proportion`` when
            ``data`` is ``"mnist"`` or ``"single_digit_mnist"``, and
            ``no_latent_op`` when ``model == "autoencoder"``, plus whatever
            ``get_n_transformations(args)`` reads.

    Returns:
        A ``functools.partial`` wrapping the selected ``run_*`` function.

    Raises:
        _UnsupportedExperimentError: (both a ``ValueError`` and a
            ``KeyError``) if ``args.data`` is not a supported dataset, or if
            no experiment is registered for the ``args.model`` / ``args.data``
            combination.
    """
    experiments = {
        "run_autoencoder_shapes": run_autoencoder_shapes,
        "run_autoencoder_mnist": run_autoencoder_mnist,
        "run_cci_vae_shapes": run_cci_vae_shapes,
        "run_cci_vae_mnist": run_cci_vae_mnist,
        # NOTE(review): single-digit MNIST reuses the full-MNIST runner, and
        # nothing here passes a single-digit flag -- confirm the runner (or
        # the data loading) distinguishes this dataset elsewhere.
        "run_cci_vae_single_digit_mnist": run_cci_vae_mnist,
    }
    name = f"run_{args.model}_{args.data}"

    # Validate the dataset before the lookup. Otherwise every unsupported
    # dataset would surface as an opaque KeyError from the dict lookup below.
    if args.data == "shapes":
        kwargs = {"n_classes": args.n_classes}
    elif args.data in {"mnist", "single_digit_mnist"}:
        kwargs = {"proportion": args.mnist_proportion}
    else:
        raise _UnsupportedExperimentError(f"dataset {args.data} not supported")

    try:
        experiment = experiments[name]
    except KeyError:
        available = ", ".join(sorted(experiments))
        raise _UnsupportedExperimentError(
            f"no experiment {name!r} for model {args.model!r} and dataset "
            f"{args.data!r}; available: {available}"
        ) from None
    print(name)

    # Standard autoencoder: optionally disable the latent operator.
    if args.model == "autoencoder" and args.no_latent_op:
        kwargs["use_latent_op"] = False

    n_rotations, n_x_translations, n_y_translations = get_n_transformations(args)
    kwargs.update(
        n_rotations=n_rotations,
        n_x_translations=n_x_translations,
        n_y_translations=n_y_translations,
        architecture=args.architecture,
        z_dim=args.z_dim,
        distribution=args.distribution,
    )
    # One partial with merged keywords is equivalent to the nested partials it
    # replaces: functools.partial flattens nested partials into a single one.
    return partial(experiment, **kwargs)