in code/src/model/seq2seq.py [0:0]
def build_seq2seq_model(params, data, cuda=True):
    """
    Build the seq2seq model: encoder, decoder, optional attribute
    discriminator and optional language model.

    The decoder is given a reconstruction loss (`decoder.loss_fn`): a
    cross-entropy over the `params.n_words` vocabulary in which the padding
    index has zero weight. The built models are optionally moved to the GPU,
    initialized with pretrained embeddings when `data` is provided, and
    optionally reloaded from a checkpoint.

    Args:
        params: experiment configuration. Attributes read here:
            `lambda_dis` / `lambda_lm` (the discriminator / LM are only built
            when these are not the string "0"), `n_words`, `pad_index`,
            `reload_model` (checkpoint path, '' to skip reloading),
            `reload_enc`, `reload_dec`, `reload_dis`, and optionally
            `cpu_thread` (treated as False when absent).
        data: dataset dictionary, or None. Must be None exactly when
            `params.cpu_thread` is set. When not None it is passed to
            `initialize_embeddings`; `data['dico']` is used to build the LM.
        cuda: if True, call `.cuda()` on every model that was built.

    Returns:
        (encoder, decoder, discriminator, lm). `discriminator` is None when
        `params.lambda_dis == "0"` and `lm` is None when
        `params.lambda_lm == "0"`.

    Raises:
        AssertionError: if `data`/`cpu_thread` are inconsistent, if the
            checkpoint file does not exist, or if `reload_dis` is requested
            while no discriminator was built.
    """
    # `data` must be None exactly when this is the CPU thread. Check it before
    # spending time building (and possibly moving to GPU) any model.
    cpu_thread = bool(getattr(params, 'cpu_thread', False))
    assert cpu_thread == (data is None), \
        "data must be None if and only if params.cpu_thread is set"

    # encoder / decoder / discriminator
    logger.info("============ Building seq2seq model - Encoder ...")
    encoder = Encoder(params)
    logger.info("")
    logger.info("============ Building seq2seq model - Decoder ...")
    decoder = Decoder(params, encoder)
    logger.info("")
    if params.lambda_dis != "0":
        logger.info("============ Building seq2seq model - Discriminator ...")
        discriminator = MultiAttrDiscriminator(params)
        logger.info("")
    else:
        discriminator = None

    # loss function for decoder reconstruction: padding tokens get zero weight
    # so they do not contribute to the loss
    loss_weight = torch.FloatTensor(params.n_words).fill_(1)
    loss_weight[params.pad_index] = 0
    # NOTE(review): `size_average` is deprecated since PyTorch 0.4.1, where
    # `reduction='mean'` is the equivalent; kept for older-version support.
    decoder.loss_fn = nn.CrossEntropyLoss(loss_weight, size_average=True)

    # language model
    if params.lambda_lm != "0":
        # NOTE(review): this indexes `data`, which is None on the CPU thread
        # (see the assertion at the top) — confirm the LM is never enabled
        # for the CPU thread.
        logger.info("============ Building seq2seq model - Language model ...")
        lm = LM(params, data['dico'])
        logger.info("")
    else:
        lm = None

    # cuda - models on CPU will be synchronized and don't need to be reloaded
    if cuda:
        for module in (encoder, decoder, discriminator, lm):
            if module is not None:
                module.cuda()

    # initialize the model with pretrained embeddings
    if data is not None:
        initialize_embeddings(encoder, decoder, params, data)

    # reload encoder / decoder / discriminator
    if params.reload_model != '':
        assert os.path.isfile(params.reload_model), \
            "Checkpoint not found: %s" % params.reload_model
        logger.info("Reloading model from %s ..." % params.reload_model)
        reloaded = torch.load(params.reload_model)
        if params.reload_enc:
            logger.info("Reloading encoder...")
            reload_model(encoder, reloaded['enc'], encoder.ENC_ATTR)
        if params.reload_dec:
            logger.info("Reloading decoder...")
            reload_model(decoder, reloaded['dec'], decoder.DEC_ATTR)
        if params.reload_dis:
            # Without this check, a disabled discriminator (lambda_dis == "0")
            # surfaces as an opaque AttributeError on `None.DIS_ATTR`.
            assert discriminator is not None, \
                "reload_dis is set but no discriminator was built (lambda_dis == \"0\")"
            logger.info("Reloading discriminator...")
            reload_model(discriminator, reloaded['dis'], discriminator.DIS_ATTR)

    # log models
    logger.info("============ Model summary")
    logger.info("Encoder: {}".format(encoder))
    logger.info("Decoder: {}".format(decoder))
    logger.info("Discriminator: {}".format(discriminator))
    logger.info("LM: {}".format(lm))
    logger.info("")
    return encoder, decoder, discriminator, lm