in code/src/model/attention.py [0:0]
def build_attention_model(params, data, cuda=True):
    """
    Build an encoder / decoder (and optionally a discriminator and a
    language model), and attach the decoder reconstruction loss function.
    """
    # encoder / decoder / discriminator
    if params.transformer:
        encoder, decoder = build_transformer_enc_dec(params)
    else:
        encoder, decoder = build_lstm_enc_dec(params)
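    # NOTE: the lambda_* coefficients below are compared as strings,
    # presumably because they can encode schedules rather than a single
    # scalar value (an assumption; the parsing is not shown here).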
    if params.lambda_dis != "0":
        logger.info("============ Building attention model - Discriminator ...")
        if params.disc_lstm_dim > 0:
            assert params.disc_lstm_layers >= 1
            discriminator = MultiAttrDiscriminatorLSTM(params)
        else:
            discriminator = MultiAttrDiscriminator(params)
        logger.info("")
    else:
        discriminator = None
    # loss function for decoder reconstruction
    loss_weight = torch.FloatTensor(params.n_words).fill_(1)
    loss_weight[params.pad_index] = 0
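    # Zeroing the weight of the padding index excludes <pad> targets from the
    # reconstruction loss. With label smoothing eps > 0, the target
    # distribution puts 1 - eps on the gold token and spreads eps over the
    # rest of the vocabulary (assuming the usual smoothing scheme; the exact
    # formulation lives in LabelSmoothedCrossEntropyLoss).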
    if params.label_smoothing <= 0:
        decoder.loss_fn = nn.CrossEntropyLoss(loss_weight, size_average=True)
    else:
        decoder.loss_fn = LabelSmoothedCrossEntropyLoss(
            params.label_smoothing,
            params.pad_index,
            size_average=True,
            weight=loss_weight,
        )
    # language model
    if params.lambda_lm != "0":
        logger.info("============ Building attention model - Language model ...")
        lm = LM(params, data['dico'])
        logger.info("")
    else:
        lm = None
    # cuda - models on CPU will be synchronized and don't need to be reloaded
    if cuda:
        encoder.cuda()
        decoder.cuda()
        if discriminator is not None:
            discriminator.cuda()
        if lm is not None:
            lm.cuda()
    # initialize the model with pretrained embeddings
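    # Sanity check: `not` binds more loosely than `^`, so this asserts that
    # `data` is None exactly when `params.cpu_thread` is set; models built in
    # CPU worker threads skip the embedding initialization below.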
    assert not (getattr(params, 'cpu_thread', False)) ^ (data is None)
    if data is not None:
        initialize_embeddings(encoder, decoder, params, data)
    # reload encoder / decoder / discriminator
    if params.reload_model != '':
        assert os.path.isfile(params.reload_model)
        logger.info("Reloading model from %s ..." % params.reload_model)
        reloaded = torch.load(params.reload_model)
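        # Each lookup below tries a short key ('enc') and falls back to a
        # long one ('encoder'), presumably for compatibility across
        # checkpoint formats.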
        if params.reload_enc:
            logger.info("Reloading encoder...")
            enc = reloaded.get('enc', reloaded.get('encoder'))
            reload_model(encoder, enc, encoder.ENC_ATTR)
        if params.reload_dec:
            logger.info("Reloading decoder...")
            dec = reloaded.get('dec', reloaded.get('decoder'))
            reload_model(decoder, dec, decoder.DEC_ATTR)
        if params.reload_dis:
            assert discriminator is not None
            logger.info("Reloading discriminator...")
            dis = reloaded.get('dis', reloaded.get('discriminator'))
            reload_model(discriminator, dis, discriminator.DIS_ATTR)
    # log models
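    # Collect parameters into a set so that weights shared between the
    # encoder and decoder (e.g. tied embeddings) are counted only once.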
    encdec_params = set(
        p
        for module in [encoder, decoder]
        for p in module.parameters()
        if p.requires_grad
    )
    num_encdec_params = sum(p.numel() for p in encdec_params)
    logger.info("============ Model summary")
    logger.info("Number of enc+dec parameters: {}".format(num_encdec_params))
    logger.info("Encoder: {}".format(encoder))
    logger.info("Decoder: {}".format(decoder))
    logger.info("Discriminator: {}".format(discriminator))
    logger.info("LM: {}".format(lm))
    logger.info("")
    return encoder, decoder, discriminator, lm
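
# Usage sketch (not part of the original file): a minimal, hypothetical
# invocation with the simplest configuration. The Namespace fields mirror
# the flags read above; their values are assumptions, not the project's
# actual defaults, and `data` is expected to come from the data loader.
#
#     from argparse import Namespace
#     params = Namespace(
#         transformer=True,   # transformer encoder/decoder
#         lambda_dis="0",     # no discriminator
#         lambda_lm="0",      # no language model
#         n_words=60000, pad_index=2,
#         label_smoothing=0.1,
#         reload_model='',
#         cpu_thread=False,
#     )
#     encoder, decoder, discriminator, lm = build_attention_model(
#         params, data, cuda=torch.cuda.is_available()
#     )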