def get_Transformer()

in main.py [0:0]


def get_Transformer(args, datasets):
    """Build the transformer constructor described by the configuration.

    Nothing is instantiated here: the result is a ``functools.partial`` over
    ``NeuralNetTransformer`` with the module, optimizer, criterion and extra
    keyword arguments already bound.

    Parameters
    ----------
    args :
        Configuration object. The fields read are ``model.Q_zy``,
        ``model.Q_zx``, ``model.loss``, ``model.Loss``, ``model.architecture``,
        ``model.gamma_force_generalization``, ``encoder.name``,
        ``encoder.architecture``, ``train.kwargs.lr``, ``train.lr_factor_zx``
        and ``keys_ignored``. The sub-configs are converted with
        ``OmegaConf.to_container(..., resolve=True)``.
    datasets :
        Indexable by ``"train"``; that entry must expose
        ``map_target_position``.

    Returns
    -------
    functools.partial
        Partial application of ``NeuralNetTransformer``.

    Notes
    -----
    When ``args.model.loss.altern_minimax`` is not positive and
    ``args.model.Loss == "DIBLoss"``, this function reassigns
    ``NeuralNetTransformer._get_params_for_optimizer`` on the class itself, so
    the change outlives this call.
    """
    logger.info("Instantiating the transformer ...")

    def _mlp_factory(cfg_node):
        # MLP constructor with the keyword arguments of ``cfg_node`` bound.
        return partial(MLP, **OmegaConf.to_container(cfg_node, resolve=True))

    # Q used for sufficiency
    sufficiency_head = _mlp_factory(args.model.Q_zy)

    # Q used for minimality
    minimality_head = _mlp_factory(args.model.Q_zx)

    # Every criterion receives the minimality head under the key "Q".
    criterion_kwargs = OmegaConf.to_container(args.model.loss, resolve=True)
    criterion_kwargs["Q"] = minimality_head

    module_kwargs = {"Q": sufficiency_head}

    # Losses that can be selected by name through ``args.model.Loss``.
    loss_registry = {
        "VIBLoss": VIBLoss,
        "ERMLoss": ERMLoss,
        "DIBLossSklearn": DIBLossAlternLinearExact,
    }

    # A head without hidden layers is treated as linear.
    linear_head = args.model.Q_zx.n_hidden_layers == 0
    n_altern = args.model.loss.altern_minimax
    transformer_kwargs = {}

    if n_altern > 0:
        # Alternating minimax: choose the matching "DIBLoss" variant.
        if linear_head:
            dib_loss = DIBLossAlternLinear
        elif args.model.loss.is_higher:
            dib_loss = DIBLossAlternHigher
        else:
            dib_loss = DIBLossAltern
        loss_registry["DIBLoss"] = dib_loss

    elif args.model.Loss == "DIBLoss":
        # Joint training: the criterion's parameters have to be handed to the
        # main (and only) optimizer.
        NeuralNetTransformer._get_params_for_optimizer = partialmethod(
            _get_params_for_optimizer, is_add_criterion=True
        )
        if linear_head:
            loss_registry["DIBLoss"] = DIBLossLinear
        else:
            loss_registry["DIBLoss"] = DIBLoss
        # Parameters matching "Q_zx*" get their own, rescaled learning rate.
        scaled_lr = args.train.kwargs.lr * args.train.lr_factor_zx
        transformer_kwargs["optimizer__param_groups"] = [
            ("Q_zx*", {"lr": scaled_lr})
        ]

    # Module: IBEncoder with the sufficiency head, an image encoder and the
    # model architecture options bound on top.
    ib_encoder = partial(IBEncoder, **module_kwargs)
    image_encoder = partial(
        get_img_encoder(args.encoder.name),
        **OmegaConf.to_container(args.encoder.architecture, resolve=True),
    )
    module_factory = partial(
        ib_encoder,
        Encoder=image_encoder,
        **OmegaConf.to_container(args.model.architecture, resolve=True),
    )

    optimizer = get_optim(args)

    # Criterion: the selected loss, given the Z -> Y criterion and loss options.
    zy_criterion = partial(
        CrossEntropyLossGeneralize,
        gamma=args.model.gamma_force_generalization,
        map_target_position=datasets["train"].map_target_position,
    )
    criterion_factory = partial(
        loss_registry[args.model.Loss],
        ZYCriterion=zy_criterion,
        **criterion_kwargs,
    )

    return partial(
        NeuralNetTransformer,
        module=module_factory,
        optimizer=optimizer,
        criterion=criterion_factory,
        callbacks__print_log__keys_ignored=args.keys_ignored,
        **transformer_kwargs,
    )