def update_config_()

in main.py


def update_config_(args):
    """Update the configuration values based on other values."""

    # increment the seed at each run
    args.seed = args.seed + args.run

    # scale the number of examples by a factor. Used to make the number of examples depend
    # on the number of labels. The factor is usually 1.
    args.datasize.n_examples = args.datasize.factor * args.datasize.n_examples

    if args.datasize.n_examples_test == "train":
        # use same number of train and test examples
        args.datasize.n_examples_test = args.datasize.n_examples

    if args.is_precompute_trnsf and args.train.trnsf_kwargs.is_train:
        # if training the transformer, the transformer and checkpoint directories must agree
        assert args.paths["trnsf_dirnames"][0] == args.paths["chckpnt_dirnames"][0]

    # monitor the training loss when targets are randomized, because validation is meaningless in that case
    if args.dataset.kwargs.is_random_targets:
        args.train.trnsf_kwargs.monitor_best = "train_loss_best"
        args.train.clf_kwargs.monitor_best = "train_loss_best"

    if not args.train.is_tensorboard:
        args.paths["tensorboard_curr_dir"] = None

    if args.experiment == "gap":
        # rescale beta per model for the gap experiment
        if args.model.name == "vib":
            args.model.loss.beta = args.model.loss.beta * 40

        elif args.model.name == "cdibL":
            args.model.loss.beta = args.model.loss.beta / 100

        elif args.model.name == "cdibS":
            args.model.loss.beta = args.model.loss.beta * 30

    if "dibL" in args.model.name:
        # dib with Q++
        args.model.Q_zx.hidden_size = args.model.Q_zy.hidden_size * 64

    if "dibS" in args.model.name:
        # dib with Q--
        args.model.Q_zx.hidden_size = args.model.Q_zy.hidden_size // 64

    if "dibXS" in args.model.name:
        # dib with Q------
        args.model.Q_zx.hidden_size = 1

    if "dibXL" in args.model.name:
        # dib with Q++++++++
        args.model.Q_zx.hidden_size = 8192

    short_long_monitor = dict(
        vloss="valid_loss_best", tloss="train_loss_best", vacc="valid_acc_best"
    )

    # use the short monitor name when building the file name
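    # invert_dict is assumed to swap keys and values, e.g. mapping "valid_loss_best" -> "vloss"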
    args.train.monitor_best = invert_dict(short_long_monitor).get(
        args.train.monitor_best, args.train.monitor_best
    )

    hyperparam_path = hyperparam_to_path(args.hyperparameters)
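    # format_container presumably substitutes the {hyperparam_path} placeholder into every
    # path string; the formatted paths are then merged back into args.paths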
    args.paths.merge_with(
        OmegaConf.create(
            format_container(args.paths, dict(hyperparam_path=hyperparam_path))
        )
    )
    # every change that should not modify the name of the file should go below this
    # ----------------------------------------------------------------------------

    # switch back to the long monitor names for use in the code
    args.train.monitor_best = short_long_monitor.get(
        args.train.monitor_best, args.train.monitor_best
    )
    args.train.trnsf_kwargs.monitor_best = short_long_monitor.get(
        args.train.trnsf_kwargs.monitor_best, args.train.trnsf_kwargs.monitor_best
    )
    args.train.clf_kwargs.monitor_best = short_long_monitor.get(
        args.train.clf_kwargs.monitor_best, args.train.clf_kwargs.monitor_best
    )

    if not args.is_precompute_trnsf:
        logger.info("Not precomputing the transformer so setting train=False.")
        args.train.trnsf_kwargs.is_train = False
        args.train.kwargs.lr = args.train.lr_clf  # ! DEV
    else:
        if args.model.name == "wdecayBob":
            args.train.weight_decay = 1e-4

        if args.model.name == "dropoutBob":
            args.encoder.architecture.dropout = 0.5

    if not args.datasize.is_valid_all_epochs and "train" in args.train.monitor_best:
        # don't validate every epoch when validation is much more expensive than training
        # and only training metrics are monitored
        rm_valid_epochs_()

    if args.model.is_joint:
        args.model.gamma_force_generalization = 1

    if "distractor" in args.clfs.name and not args.is_precompute_trnsf:
        args.dataset.is_use_distractor = True

    if "random" in args.clfs.name and not args.is_precompute_trnsf:
        # use random targets for the classifier; only applied when not precomputing the transformer,
        # so the encoder itself is never trained on randomized targets
        args.dataset.kwargs.is_random_targets = True
        args.train.clf_kwargs.monitor_best = "train_loss_best"  # don't monitor val

    if isinstance(args.train.kwargs.lr, str) and "|" in args.train.kwargs.lr:
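        # e.g. lr == "0.001|10" -> kwargs.lr = 0.001, lr_factor_zx = 10.0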
        lr, lr_factor_zx = args.train.kwargs.lr.split("|")
        args.train.kwargs.lr = float(lr)
        args.train.lr_factor_zx = float(lr_factor_zx)

    if args.model.name == "vibL":
        # keep Alice the same but increase Bob's view of Alice
        # vib with a better approximation of I[Z,Y] (Q++)
        args.model.Q_zy.hidden_size = args.model.Q_zy.hidden_size * 16

    if args.model.name == "wdecay":
        args.train.weight_decay = 1e-4

    if "correlation" in args.experiment:
        if args.train.optim == "rmsprop":
            if args.train.weight_decay == 0.0005:
                args.train.weight_decay = 0.0003

        elif args.train.optim == "sgd":
            args.train.kwargs.lr = args.train.kwargs.lr * 50

    if "perminvcdib" in args.model.name:
        args.encoder.architecture.hidden_size = [1024]
        args.model.architecture.z_dim = 1024
        args.model.Q_zy.hidden_size = 256
        args.model.Q_zy.n_hidden_layers = 1
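

A minimal, standalone sketch of the path-formatting pattern used by the args.paths.merge_with(...) call above. The two helpers below are hypothetical stand-ins: hyperparam_to_path is assumed to join the hyperparameters into a filesystem-friendly string, and format_container to apply str.format to every string entry of a (flat, simplified) config. Only OmegaConf.create, OmegaConf.to_container and merge_with are real omegaconf API.

from omegaconf import OmegaConf


def hyperparam_to_path(hyperparams):
    # hypothetical stand-in: join hyperparameters into a path fragment
    return "_".join(f"{k}={v}" for k, v in hyperparams.items())


def format_container(container, fmt_kwargs):
    # hypothetical stand-in: apply str.format to every string value of a flat config
    plain = OmegaConf.to_container(container, resolve=False)
    return {
        k: v.format(**fmt_kwargs) if isinstance(v, str) else v
        for k, v in plain.items()
    }


paths = OmegaConf.create({"chckpnt_dirnames": "results/{hyperparam_path}/chckpnt"})
hyperparam_path = hyperparam_to_path({"beta": 0.01, "seed": 123})
paths.merge_with(
    OmegaConf.create(format_container(paths, dict(hyperparam_path=hyperparam_path)))
)
print(paths.chckpnt_dirnames)  # results/beta=0.01_seed=123/chckpnt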