def build_model()

in XLM/src/model/__init__.py


def build_model(params, dico):
    """
    Build model.
    """
    if params.encoder_only:
        # build
        model = TransformerModel(
            params, dico, is_encoder=True, with_output=True)

        # reload pretrained word embeddings
        if params.reload_emb != '':
            word2id, embeddings = load_embeddings(params.reload_emb, params)
            set_pretrain_emb(model, dico, word2id, embeddings)

        # reload a pretrained model
        if params.reload_model != '':
            logger.info("Reloading model from %s ..." % params.reload_model)
            reloaded = torch.load(
                params.reload_model,
                map_location=lambda storage, loc: storage.cuda(params.local_rank)
            )['model']
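            # strip the 'module.' prefix left by checkpoints saved from a
            # (Distributed)DataParallel-wrapped model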
            if all([k.startswith('module.') for k in reloaded.keys()]):
                reloaded = {k[len('module.'):]: v for k, v in reloaded.items()}

            # # HACK to reload models with less layers
            # for i in range(12, 24):
            #     for k in TRANSFORMER_LAYER_PARAMS:
            #         k = k % i
            #         if k in model.state_dict() and k not in reloaded:
            #             logger.warning("Parameter %s not found. Ignoring ..." % k)
            #             reloaded[k] = model.state_dict()[k]

            model.load_state_dict(reloaded)

        logger.info("Model: {}".format(model))
        logger.info("Number of parameters (model): %i" % sum(
            [p.numel() for p in model.parameters() if p.requires_grad]))

        return [model.cuda()]

    else:
        # build
        # TODO: only output when necessary - len(params.clm_steps + params.mlm_steps) > 0
        encoder = TransformerModel(
            params, dico, is_encoder=True, with_output=True)

        if params.separate_decoders:
            decoders = [TransformerModel(
                params, dico, is_encoder=False, with_output=True) for _ in params.lang2id.values()]
        else:
            decoders = [TransformerModel(
                params, dico, is_encoder=False, with_output=True)]
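        # `decoders` now holds either one decoder per language in params.lang2id
        # (separate_decoders) or a single decoder shared across all languages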

        # share the attention parameters of the first n_share_dec decoder layers
        # across all decoders: the modules are aliased, so a single set of
        # attention weights is created and updated
        for layer in range(params.n_layers_decoder):
            if layer <= params.n_share_dec - 1:
                assert params.amp == -1, "sharing layers is not supported with AMP"
                logger.info(
                    "Sharing decoder attention parameters for layer %i" % layer)
                for i in range(1, len(decoders)):
                    decoders[i].attentions[layer] = decoders[0].attentions[layer]

        # reload pretrained word embeddings
        if params.reload_emb != '':
            word2id, embeddings = load_embeddings(params.reload_emb, params)
            set_pretrain_emb(encoder, dico, word2id, embeddings)
            # apply the pretrained embeddings to each decoder individually
            for dec in decoders:
                set_pretrain_emb(dec, dico, word2id, embeddings)

        # reload a pretrained model
        if params.reload_model != '':
            enc_path, dec_path = params.reload_model.split(',')
            assert not (enc_path == '' and dec_path == '')

            # reload encoder
            if enc_path != '':
                logger.info("Reloading encoder from %s ..." % enc_path)
                enc_reload = torch.load(
                    enc_path, map_location=lambda storage, loc: storage.cuda(params.local_rank))
                enc_reload = enc_reload['model' if 'model' in enc_reload else 'encoder']
                if all([k.startswith('module.') for k in enc_reload.keys()]):
                    enc_reload = {k[len('module.'):]: v for k, v in enc_reload.items()}

                # HACK to reload models trained with fewer languages
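                # If the current run has twice as many languages as the
                # checkpoint (presumably two variants per original language),
                # each pretrained language embedding is duplicated so that rows
                # 2i and 2i+1 both start from the checkpoint's language i; the
                # "+1" case additionally prepends a copy of the first row.
                # Position embeddings are tiled in the same way below when the
                # new model allows twice as many positions.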
                n_langs = len(params.langs)
                n_langs_reload = enc_reload['lang_embeddings.weight'].size()[0]
                assert n_langs == n_langs_reload or n_langs == 2 * \
                    n_langs_reload or n_langs == 2 * n_langs_reload + 1
                if n_langs == 2 * n_langs_reload:
                    enc_reload['lang_embeddings.weight'] = enc_reload['lang_embeddings.weight'].transpose(
                        0, 1).repeat_interleave(2, 1).transpose(0, 1)
                elif n_langs == 2 * n_langs_reload + 1:
                    enc_reload['lang_embeddings.weight'] = enc_reload['lang_embeddings.weight'].transpose(
                        0, 1).repeat_interleave(2, 1).transpose(0, 1)
                    enc_reload['lang_embeddings.weight'] = torch.cat(
                        [enc_reload['lang_embeddings.weight'][0, :].unsqueeze(dim=0), enc_reload['lang_embeddings.weight']])

                if encoder.position_embeddings.weight.size()[0] == 2 * enc_reload['position_embeddings.weight'].size()[0]:
                    enc_reload['position_embeddings.weight'] = enc_reload['position_embeddings.weight'].repeat(
                        2, 1)

                encoder.load_state_dict(enc_reload)

            # reload decoders
            if dec_path != '':
                for dec in decoders:
                    logger.info("Reloading decoders from %s ..." % dec_path)
                    dec_reload = torch.load(
                        dec_path, map_location=lambda storage, loc: storage.cuda(params.local_rank))
                    dec_reload = dec_reload['model' if 'model' in dec_reload else 'decoder']
                    if all([k.startswith('module.') for k in dec_reload.keys()]):
                        dec_reload = {k[len('module.'):]: v for k, v in dec_reload.items()}

                    # HACK to reload models trained with fewer languages
                    n_langs = len(params.langs)
                    n_langs_reload = dec_reload['lang_embeddings.weight'].size()[0]
                    assert n_langs == n_langs_reload or n_langs == 2 * \
                        n_langs_reload or n_langs == 2 * n_langs_reload + 1
                    if n_langs == 2 * n_langs_reload:
                        dec_reload['lang_embeddings.weight'] = dec_reload['lang_embeddings.weight'].transpose(
                            0, 1).repeat_interleave(2, 1).transpose(0, 1)
                    elif n_langs == 2 * n_langs_reload + 1:
                        dec_reload['lang_embeddings.weight'] = dec_reload['lang_embeddings.weight'].transpose(
                            0, 1).repeat_interleave(2, 1).transpose(0, 1)
                        dec_reload['lang_embeddings.weight'] = torch.cat(
                            [dec_reload['lang_embeddings.weight'][0, :].unsqueeze(dim=0), dec_reload['lang_embeddings.weight']])
                    if dec.position_embeddings.weight.size()[0] == 2 * dec_reload['position_embeddings.weight'].size()[0]:
                        dec_reload['position_embeddings.weight'] = dec_reload['position_embeddings.weight'].repeat(
                            2, 1)

                    # decoder-only parameters (e.g. cross-attention weights) that
                    # are missing from the checkpoint keep the freshly
                    # initialized values of the newly built decoder
                    for i in range(params.n_layers_decoder):
                        for name in DECODER_ONLY_PARAMS:
                            if name % i not in dec_reload:
                                logger.warning(
                                    "Parameter %s not found." % (name % i))
                                dec_reload[name % i] = dec.state_dict()[name % i]
                    dec.load_state_dict(dec_reload)

        logger.debug("Encoder: {}".format(encoder))
        logger.debug("Decoder: {}".format(decoders))
        logger.info("Number of parameters (encoder): %i" % sum(
            [p.numel() for p in encoder.parameters() if p.requires_grad]))
        logger.info("Number of parameters (decoders): %i" % sum(
            [p.numel() for p in decoders[0].parameters() if p.requires_grad]))
        logger.info(f"Number of decoders: {len(decoders)}")

        return [encoder.cuda()], [dec.cuda() for dec in decoders]
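
A minimal, hedged usage sketch of the assumed calling convention (the call-site names `params` and `dico` are taken from the function signature; the actual training script in this fork may unpack the return values differently):

# Hypothetical call site: in encoder-only mode build_model returns a
# one-element list, otherwise a pair ([encoder], list of decoders).
if params.encoder_only:
    [model] = build_model(params, dico)
else:
    [encoder], decoders = build_model(params, dico)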