def build_seq2seq_model()

in code/src/model/seq2seq.py [0:0]


def build_seq2seq_model(params, data, cuda=True):
    """
    Build an encoder / decoder pair (plus an optional attribute discriminator
    and an optional language model) and attach the decoder reconstruction loss.

    Parameters
    ----------
    params : namespace of hyperparameters. Fields read here: ``lambda_dis``
        and ``lambda_lm`` (the string "0" disables the discriminator / the
        LM respectively), ``n_words``, ``pad_index``, ``cpu_thread``
        (optional, defaults to False), ``reload_model``, ``reload_enc``,
        ``reload_dec`` and ``reload_dis``.
    data : dict or None. ``data['dico']`` is passed to the LM when the LM is
        built; when ``cuda`` is True and ``data`` is not None it is also
        used to initialize the pretrained embeddings.
    cuda : if True, move the models to GPU, initialize embeddings and
        optionally reload weights from ``params.reload_model``. With
        ``cuda=False`` the models are only constructed.

    Returns
    -------
    tuple
        ``(encoder, decoder, discriminator, lm)``; ``discriminator`` and
        ``lm`` are None when disabled.
    """
    # encoder / decoder / discriminator
    logger.info("============ Building seq2seq model - Encoder ...")
    encoder = Encoder(params)
    logger.info("")
    logger.info("============ Building seq2seq model - Decoder ...")
    decoder = Decoder(params, encoder)
    logger.info("")
    if params.lambda_dis != "0":
        logger.info("============ Building seq2seq model - Discriminator ...")
        discriminator = MultiAttrDiscriminator(params)
        logger.info("")
    else:
        discriminator = None

    # loss function for decoder reconstruction: padding tokens get zero weight
    # so they do not contribute to the loss. The loss module is attached to the
    # decoder *before* decoder.cuda() below, so its weight tensor is moved to
    # the GPU together with the decoder.
    loss_weight = torch.FloatTensor(params.n_words).fill_(1)
    loss_weight[params.pad_index] = 0
    decoder.loss_fn = nn.CrossEntropyLoss(loss_weight, size_average=True)

    # language model
    if params.lambda_lm != "0":
        logger.info("============ Building seq2seq model - Language model ...")
        lm = LM(params, data['dico'])
        logger.info("")
    else:
        lm = None

    # cuda - models on CPU will be synchronized and don't need to be reloaded
    if cuda:
        encoder.cuda()
        decoder.cuda()
        if discriminator is not None:
            discriminator.cuda()
        if lm is not None:
            lm.cuda()

        # initialize the model with pretrained embeddings
        # invariant: data is None exactly when params.cpu_thread is set
        assert not (getattr(params, 'cpu_thread', False)) ^ (data is None)
        if data is not None:
            initialize_embeddings(encoder, decoder, params, data)

        # reload encoder / decoder / discriminator
        if params.reload_model != '':
            assert os.path.isfile(params.reload_model)
            logger.info("Reloading model from %s ...", params.reload_model)
            # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
            reloaded = torch.load(params.reload_model)
            if params.reload_enc:
                logger.info("Reloading encoder...")
                reload_model(encoder, reloaded['enc'], encoder.ENC_ATTR)
            if params.reload_dec:
                logger.info("Reloading decoder...")
                reload_model(decoder, reloaded['dec'], decoder.DEC_ATTR)
            if params.reload_dis:
                if discriminator is None:
                    # the discriminator is disabled (lambda_dis == "0"), so
                    # there is nothing to reload weights into
                    logger.warning("reload_dis is set but no discriminator was "
                                   "built (lambda_dis == \"0\"); skipping.")
                else:
                    logger.info("Reloading discriminator...")
                    reload_model(discriminator, reloaded['dis'], discriminator.DIS_ATTR)

    # log models
    logger.info("============ Model summary")
    logger.info("Encoder: {}".format(encoder))
    logger.info("Decoder: {}".format(decoder))
    logger.info("Discriminator: {}".format(discriminator))
    logger.info("LM: {}".format(lm))
    logger.info("")

    return encoder, decoder, discriminator, lm