in NMT/src/model/attention.py [0:0]
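# (module-level imports assumed from the surrounding file: os, torch,
# torch.nn as nn, and the package-local helpers used below, e.g. logger,
# build_transformer_enc_dec, build_lstm_enc_dec, Discriminator, LM,
# LabelSmoothedCrossEntropyLoss, initialize_embeddings, reload_model)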
def build_attention_model(params, data, cuda=True):
    """
    Build an encoder / decoder and the decoder reconstruction loss function.
    """
    # encoder / decoder / discriminator
    if params.transformer:
        encoder, decoder = build_transformer_enc_dec(params)
    else:
        encoder, decoder = build_lstm_enc_dec(params)
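
    # the adversarial discriminator is only built when its loss coefficient
    # is active (lambda_dis is string-valued; "0" and "-1" both disable it)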
    if params.lambda_dis not in ["0", "-1"]:
        logger.info("============ Building attention model - Discriminator ...")
        discriminator = Discriminator(params)
        logger.info("")
    else:
        discriminator = None

    # loss function for decoder reconstruction
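    # (one loss per language: params.n_words holds each language's vocabulary
    # size, and the padding index gets zero weight so PAD positions are ignored)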
    loss_fn = []
    for n_words in params.n_words:
        loss_weight = torch.FloatTensor(n_words).fill_(1)
        loss_weight[params.pad_index] = 0
        if params.label_smoothing <= 0:
            # `size_average` is the legacy PyTorch reduction argument,
            # equivalent to reduction='mean' in recent versions
            loss_fn.append(nn.CrossEntropyLoss(loss_weight, size_average=True))
        else:
            loss_fn.append(LabelSmoothedCrossEntropyLoss(
                params.label_smoothing,
                params.pad_index,
                size_average=True,
                weight=loss_weight,
            ))
    decoder.loss_fn = nn.ModuleList(loss_fn)

    # language model
    if params.lambda_lm not in ["0", "-1"]:
        logger.info("============ Building attention model - Language model ...")
        lm = LM(params, encoder, decoder)
        logger.info("")
    else:
        lm = None

    # cuda - models on CPU will be synchronized and don't need to be reloaded
    if cuda:
        encoder.cuda()
        decoder.cuda()
        if len(params.vocab) > 0:
            # vocab_mask_neg holds plain tensors in a Python list, so they
            # are not moved by Module.cuda() and must be transferred manually
            decoder.vocab_mask_neg = [x.cuda() for x in decoder.vocab_mask_neg]
        if discriminator is not None:
            discriminator.cuda()
        if lm is not None:
            lm.cuda()

    # initialize the model with pretrained embeddings
    # (the assert enforces that data is None exactly when this is a CPU
    # thread; CPU threads skip the initialization)
    assert not (getattr(params, 'cpu_thread', False)) ^ (data is None)
    if data is not None:
        initialize_embeddings(encoder, decoder, params, data)

    # reload encoder / decoder / discriminator
    if params.reload_model != '':
        assert os.path.isfile(params.reload_model)
        logger.info("Reloading model from %s ..." % params.reload_model)
        reloaded = torch.load(params.reload_model)
        # checkpoints may store components under either a short key
        # ('enc') or a long one ('encoder'); try both
        if params.reload_enc:
            logger.info("Reloading encoder...")
            enc = reloaded.get('enc', reloaded.get('encoder'))
            reload_model(encoder, enc, encoder.ENC_ATTR)
        if params.reload_dec:
            logger.info("Reloading decoder...")
            dec = reloaded.get('dec', reloaded.get('decoder'))
            reload_model(decoder, dec, decoder.DEC_ATTR)
        if params.reload_dis:
            assert discriminator is not None
            logger.info("Reloading discriminator...")
            dis = reloaded.get('dis', reloaded.get('discriminator'))
            reload_model(discriminator, dis, discriminator.DIS_ATTR)

    # log models
    # (collect parameters into a set so weights shared between the encoder
    # and the decoder are only counted once)
    encdec_params = set(
        p
        for module in [encoder, decoder]
        for p in module.parameters()
        if p.requires_grad
    )
    num_encdec_params = sum(p.numel() for p in encdec_params)
    logger.info("============ Model summary")
    logger.info("Number of enc+dec parameters: {}".format(num_encdec_params))
    logger.info("Encoder: {}".format(encoder))
    logger.info("Decoder: {}".format(decoder))
    logger.info("Discriminator: {}".format(discriminator))
    logger.info("LM: {}".format(lm))
    logger.info("")

    return encoder, decoder, discriminator, lm
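

# Usage sketch (hypothetical; `TrainerMT` and its signature are assumptions,
# not part of this file): the returned components are typically handed off
# to a trainer along with the data and parameters.
#
#   encoder, decoder, discriminator, lm = build_attention_model(params, data)
#   trainer = TrainerMT(encoder, decoder, discriminator, lm, data, params)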