# main.py
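# Imports sketch (assumption): OmegaConf ships with the `omegaconf` package; the
# remaining helpers used below (logger, invert_dict, hyperparam_to_path,
# format_container, rm_valid_epochs_) are project-local and assumed to be
# defined or imported elsewhere in main.py.
from omegaconf import OmegaConf
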
def update_config_(args):
    """Update the configuration in place, deriving some values from others."""
    # increment the seed at each run
    args.seed = args.seed + args.run
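    # e.g. seed=123 with run=2 gives seed=125, so each run gets a distinct but
    # deterministic seed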
    # scale the number of examples by a factor, e.g. so that the number of
    # examples can depend on the number of labels. Usually the factor is 1.
    args.datasize.n_examples = args.datasize.factor * args.datasize.n_examples
    if args.datasize.n_examples_test == "train":
        # use the same number of test examples as train examples
        args.datasize.n_examples_test = args.datasize.n_examples
    if args.is_precompute_trnsf and args.train.trnsf_kwargs.is_train:
        # when training the transformer, the transformer and checkpoint paths must agree
        assert args.paths["trnsf_dirnames"][0] == args.paths["chckpnt_dirnames"][0]
    # when labels are randomized, validation metrics are meaningless, so monitor
    # the training loss instead
    if args.dataset.kwargs.is_random_targets:
        args.train.trnsf_kwargs.monitor_best = "train_loss_best"
        args.train.clf_kwargs.monitor_best = "train_loss_best"
    if not args.train.is_tensorboard:
        args.paths["tensorboard_curr_dir"] = None
    if args.experiment == "gap":
        # model-specific rescaling of beta for the gap experiment
        if args.model.name == "vib":
            args.model.loss.beta = args.model.loss.beta * 40
        elif args.model.name == "cdibL":
            args.model.loss.beta = args.model.loss.beta / 100
        elif args.model.name == "cdibS":
            args.model.loss.beta = args.model.loss.beta * 30
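        # e.g. a hypothetical beta=0.001 becomes 0.04 for vib, 1e-05 for cdibL,
        # and 0.03 for cdibS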
if "dibL" in args.model.name:
# dib with Q++
args.model.Q_zx.hidden_size = args.model.Q_zy.hidden_size * 64
if "dibS" in args.model.name:
# dib with Q--
args.model.Q_zx.hidden_size = args.model.Q_zy.hidden_size // 64
if "dibXS" in args.model.name:
# dib with Q------
args.model.Q_zx.hidden_size = 1
if "dibXL" in args.model.name:
# dib with Q++++++++
args.model.Q_zx.hidden_size = 8192
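    # e.g. with a hypothetical Q_zy.hidden_size of 128, dibL gives
    # Q_zx.hidden_size = 128 * 64 = 8192 while dibS gives 128 // 64 = 2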
    short_long_monitor = dict(
        vloss="valid_loss_best", tloss="train_loss_best", vacc="valid_acc_best"
    )
    # use the short version in the output file name
    args.train.monitor_best = invert_dict(short_long_monitor).get(
        args.train.monitor_best, args.train.monitor_best
    )
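    # invert_dict presumably swaps keys and values, i.e. it maps
    # {"vloss": "valid_loss_best", ...} to {"valid_loss_best": "vloss", ...},
    # so long monitor names are shortened here while unknown names pass
    # through unchanged via .get's default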
    hyperparam_path = hyperparam_to_path(args.hyperparameters)
    args.paths.merge_with(
        OmegaConf.create(
            format_container(args.paths, dict(hyperparam_path=hyperparam_path))
        )
    )
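    # the merge above fills "{hyperparam_path}" style placeholders in the path
    # templates; e.g. a hypothetical template "results/{hyperparam_path}" would
    # become "results/beta_0.001_seed_123"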
    # every change that should not affect the name of the output files must go
    # below this line
    # ----------------------------------------------------------------------------
    # use the long version in the code
    args.train.monitor_best = short_long_monitor.get(
        args.train.monitor_best, args.train.monitor_best
    )
    args.train.trnsf_kwargs.monitor_best = short_long_monitor.get(
        args.train.trnsf_kwargs.monitor_best, args.train.trnsf_kwargs.monitor_best
    )
    args.train.clf_kwargs.monitor_best = short_long_monitor.get(
        args.train.clf_kwargs.monitor_best, args.train.clf_kwargs.monitor_best
    )
    if not args.is_precompute_trnsf:
        logger.info("Not precomputing the transformer so setting train=False.")
        args.train.trnsf_kwargs.is_train = False
        args.train.kwargs.lr = args.train.lr_clf  # ! DEV
    else:
        if args.model.name == "wdecayBob":
            args.train.weight_decay = 1e-4
        if args.model.name == "dropoutBob":
            args.encoder.architecture.dropout = 0.5
    if not args.datasize.is_valid_all_epochs and "train" in args.train.monitor_best:
        # skip validating at every epoch when validation is much more expensive
        # than training and only training metrics are being monitored
        rm_valid_epochs_()
    if args.model.is_joint:
        args.model.gamma_force_generalization = 1
    if "distractor" in args.clfs.name and not args.is_precompute_trnsf:
        args.dataset.is_use_distractor = True
    if "random" in args.clfs.name and not args.is_precompute_trnsf:
        # randomize the targets for the classifier only: the check on
        # is_precompute_trnsf ensures the encoder itself is not trained on
        # random labels
        args.dataset.kwargs.is_random_targets = True
        args.train.clf_kwargs.monitor_best = "train_loss_best"  # don't monitor val
    if isinstance(args.train.kwargs.lr, str) and "|" in args.train.kwargs.lr:
        lr, lr_factor_zx = args.train.kwargs.lr.split("|")
        args.train.kwargs.lr = float(lr)
        args.train.lr_factor_zx = float(lr_factor_zx)
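    # e.g. a hypothetical lr="0.01|10" yields kwargs.lr=0.01 and
    # lr_factor_zx=10.0, i.e. a base learning rate plus a separate multiplier
    # (for the Q_zx branch, going by the name)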
if args.model.name == "vibL":
# keep alice the same but increase bob view of alice
# vib with better approx of I[Z,Y] Q++
args.model.Q_zy.hidden_size = args.model.Q_zy.hidden_size * 16
if args.model.name == "wdecay":
args.train.weight_decay = 1e-4
if "correlation" in args.experiment:
if args.train.optim == "rmsprop":
if args.train.weight_decay == 0.0005:
args.train.weight_decay = 0.0003
elif args.train.optim == "sgd":
args.train.kwargs.lr = args.train.kwargs.lr * 50
if "perminvcdib" in args.model.name:
args.encoder.architecture.hidden_size = [1024]
args.model.architecture.z_dim = 1024
args.model.Q_zy.hidden_size = 256
args.model.Q_zy.n_hidden_layers = 1
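
# Usage sketch (assumption): `args` is an OmegaConf DictConfig, e.g. composed by
# Hydra or loaded from YAML; update_config_ mutates it in place before training:
#
#   args = OmegaConf.load("config.yaml")  # hypothetical config file
#   update_config_(args)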