in main.py [0:0]
def get_Transformer(args, datasets):
    """Build and return a partially-applied transformer constructor.

    Reads the OmegaConf configuration in ``args`` to assemble the encoder
    module, optimizer, and loss (criterion), and returns a ``functools.partial``
    over ``NeuralNetTransformer`` that is ready to be instantiated.
    """
    logger.info("Instantiating the transformer ...")

    # Q head used for the sufficiency term.
    head_zy = partial(MLP, **OmegaConf.to_container(args.model.Q_zy, resolve=True))
    # Q head used for the minimality term.
    head_zx = partial(MLP, **OmegaConf.to_container(args.model.Q_zx, resolve=True))

    loss_kwargs = OmegaConf.to_container(args.model.loss, resolve=True)
    loss_kwargs["Q"] = head_zx
    transformer_kwargs = {"Q": head_zy}

    # Registry mapping the configured loss name to its implementation.
    loss_registry = {
        "VIBLoss": VIBLoss,
        "ERMLoss": ERMLoss,
        "DIBLossSklearn": DIBLossAlternLinearExact,
    }

    # Q_zx with zero hidden layers means the minimality head is linear.
    is_linear = args.model.Q_zx.n_hidden_layers == 0
    extra_kwargs = {}

    if args.model.loss.altern_minimax > 0:
        # Alternating minimax training: pick the DIB variant matching Q_zx.
        if is_linear:
            loss_registry["DIBLoss"] = DIBLossAlternLinear
        else:
            loss_registry["DIBLoss"] = (
                DIBLossAlternHigher if args.model.loss.is_higher else DIBLossAltern
            )
    elif args.model.Loss == "DIBLoss":
        # in the case where doing joint training you need to give the parameters
        # of the criterion to the main (and only) optimizer
        NeuralNetTransformer._get_params_for_optimizer = partialmethod(
            _get_params_for_optimizer, is_add_criterion=True
        )
        loss_registry["DIBLoss"] = DIBLossLinear if is_linear else DIBLoss
        # Separate learning rate for the Q_zx parameter group.
        extra_kwargs["optimizer__param_groups"] = [
            ("Q_zx*", {"lr": args.train.kwargs.lr * args.train.lr_factor_zx})
        ]

    return partial(
        NeuralNetTransformer,
        module=partial(
            partial(IBEncoder, **transformer_kwargs),
            Encoder=partial(
                get_img_encoder(args.encoder.name),
                **OmegaConf.to_container(args.encoder.architecture, resolve=True),
            ),
            **OmegaConf.to_container(args.model.architecture, resolve=True),
        ),
        optimizer=get_optim(args),
        criterion=partial(
            loss_registry[args.model.Loss],
            ZYCriterion=partial(
                CrossEntropyLossGeneralize,
                gamma=args.model.gamma_force_generalization,
                map_target_position=datasets["train"].map_target_position,
            ),
            **loss_kwargs,
        ),
        callbacks__print_log__keys_ignored=args.keys_ignored,
        **extra_kwargs,
    )