def build_attention_model()

in code/src/model/attention.py [0:0]


def build_attention_model(params, data, cuda=True):
    """
    Build a encoder / decoder, and the decoder reconstruction loss function.
    """
    # encoder / decoder / discriminator
    if params.transformer:
        encoder, decoder = build_transformer_enc_dec(params)
    else:
        encoder, decoder = build_lstm_enc_dec(params)
    if params.lambda_dis != "0":
        logger.info("============ Building attention model - Discriminator ...")
        if params.disc_lstm_dim > 0:
            assert params.disc_lstm_layers >= 1
            discriminator = MultiAttrDiscriminatorLSTM(params)
        else:
            discriminator = MultiAttrDiscriminator(params)
        logger.info("")
    else:
        discriminator = None

    # loss function for decoder reconstruction
    loss_weight = torch.FloatTensor(params.n_words).fill_(1)
    loss_weight[params.pad_index] = 0
    if params.label_smoothing <= 0:
        decoder.loss_fn = nn.CrossEntropyLoss(loss_weight, size_average=True)
    else:
        decoder.loss_fn = LabelSmoothedCrossEntropyLoss(
            params.label_smoothing,
            params.pad_index,
            size_average=True,
            weight=loss_weight,
        )

    # language model
    if params.lambda_lm != "0":
        logger.info("============ Building attention model - Language model ...")
        lm = LM(params, data['dico'])
        logger.info("")
    else:
        lm = None

    # cuda - models left on the CPU will be synchronized separately and don't need to be reloaded here
    if cuda:
        encoder.cuda()
        decoder.cuda()
        if discriminator is not None:
            discriminator.cuda()
        if lm is not None:
            lm.cuda()

        # initialize the model with pretrained embeddings
        assert not (getattr(params, 'cpu_thread', False)) ^ (data is None)
        if data is not None:
            initialize_embeddings(encoder, decoder, params, data)

        # reload encoder / decoder / discriminator
        if params.reload_model != '':
            assert os.path.isfile(params.reload_model)
            logger.info("Reloading model from %s ..." % params.reload_model)
            reloaded = torch.load(params.reload_model)
            if params.reload_enc:
                logger.info("Reloading encoder...")
                enc = reloaded.get('enc', reloaded.get('encoder'))
                reload_model(encoder, enc, encoder.ENC_ATTR)
            if params.reload_dec:
                logger.info("Reloading decoder...")
                dec = reloaded.get('dec', reloaded.get('decoder'))
                reload_model(decoder, dec, decoder.DEC_ATTR)
            if params.reload_dis:
                assert discriminator is not None
                logger.info("Reloading discriminator...")
                dis = reloaded.get('dis', reloaded.get('discriminator'))
                reload_model(discriminator, dis, discriminator.DIS_ATTR)

    # log models
    encdec_params = set(
        p
        for module in [encoder, decoder]
        for p in module.parameters()
        if p.requires_grad
    )
    num_encdec_params = sum(p.numel() for p in encdec_params)
    logger.info("============ Model summary")
    logger.info("Number of enc+dec parameters: {}".format(num_encdec_params))
    logger.info("Encoder: {}".format(encoder))
    logger.info("Decoder: {}".format(decoder))
    logger.info("Discriminator: {}".format(discriminator))
    logger.info("LM: {}".format(lm))
    logger.info("")

    return encoder, decoder, discriminator, lm
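
A minimal usage sketch is shown below. The `params` namespace is illustrative only: a real run builds it from the experiment's argument parser, and it must also carry every field expected by build_transformer_enc_dec / build_lstm_enc_dec (model dimensions, layer counts, and so on), while `data` is the dictionary produced by the data-loading code (containing at least 'dico').

# Hypothetical usage sketch -- the field names below are assumptions for
# illustration, not the full set of options the real argument parser defines.
from argparse import Namespace

params = Namespace(
    transformer=True,       # pick the transformer encoder / decoder
    lambda_dis="0",         # "0" disables the discriminator branch
    lambda_lm="0",          # "0" disables the language model branch
    n_words=30000,          # vocabulary size for the loss weight vector
    pad_index=2,            # padding token index (its loss weight is set to 0)
    label_smoothing=0.0,    # <= 0 falls back to plain nn.CrossEntropyLoss
    reload_model='',        # empty string skips checkpoint reloading
    # ... plus all encoder / decoder hyper-parameters required by
    # build_transformer_enc_dec (dimensions, number of layers, ...)
)

# `data` comes from the experiment's data loader; it must not be None when cuda=True
encoder, decoder, discriminator, lm = build_attention_model(params, data, cuda=True)
# discriminator and lm are both None here because the two lambdas are "0"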