# code/src/trainer.py
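# Assumed module-level context for this excerpt (not shown here): the usual
# `import torch`, `import torch.nn.functional as F`, and a module-level
# `logger`. `otf_bt` is a method of the trainer class, so `self` carries the
# encoder, decoder, params, and stats used below.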
def otf_bt(self, batch, lambda_xe, backprop_temperature):
"""
On the fly back-translation.
"""
    params = self.params
    sent1, len1, attr1 = batch['sent1'], batch['len1'], batch['attr1']
    sent2, len2, attr2 = batch['sent2'], batch['len2'], batch['attr2']
    sent3, len3, attr3 = batch['sent3'], batch['len3'], batch['attr3']
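
    # lambda_xe is the coefficient on this cross-entropy term; when it is 0
    # the step contributes nothing, so the generated batch is discarded.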
    if lambda_xe == 0:
        logger.warning("Unused generated CPU batch!")
        return
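
    # a single vocabulary (params.n_words) is shared by both decoding steps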
    n_words2 = params.n_words
    n_words3 = params.n_words
    self.encoder.train()
    self.decoder.train()

    # prepare batch
    sent1, sent2, sent3 = sent1.cuda(), sent2.cuda(), sent3.cuda()
    attr1, attr2, attr3 = attr1.cuda(), attr2.cuda(), attr3.cuda()
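    # sentences are laid out (seq_len, batch), so dim 1 is the batch size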
    bs = sent1.size(1)
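
    # backprop_temperature == -1 means sent2 is consumed as hard token ids, so
    # no gradient flows back through its generation; any other value would
    # instead feed the encoder a temperature softmax over decoder scores (the
    # unimplemented branch below).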
    if backprop_temperature == -1:
        # attr2 -> attr3
        encoded = self.encoder(sent2, len2)
    else:
        raise Exception("Not implemented for attributes yet! Need to add attribute embedding below.")
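        # NOTE: everything below the raise is unreachable by design; it is
        # kept as a sketch of differentiable back-translation, pending the
        # attribute embeddings mentioned in the exception message.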
        # attr1 -> attr2
        encoded = self.encoder(sent1, len1)
        scores = self.decoder(encoded, sent2[:-1], attr2)
        assert scores.size() == (len2.max() - 1, bs, n_words2)

        # attr2 -> attr3
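        # (soft-input trick: the encoder receives a (seq_len, bs, n_words2)
        # tensor stacking a one-hot BOS row on per-position softmax
        # distributions; a lower backprop_temperature sharpens the
        # distributions toward argmax)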
        bos = torch.cuda.FloatTensor(1, bs, n_words2).zero_()
        bos[0, :, params.bos_index] = 1
        sent2_input = torch.cat([bos, F.softmax(scores / backprop_temperature, -1)], 0)
        encoded = self.encoder(sent2_input, len2)

    # cross-entropy scores / loss
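    # (teacher forcing: the decoder consumes sent3[:-1] and is scored against
    # sent3[1:], so each position predicts the next reference token)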
    scores = self.decoder(encoded, sent3[:-1], attr3)
    xe_loss = self.decoder.loss_fn(scores.view(-1, n_words3), sent3[1:].view(-1))
    self.stats['xe_bt'].append(xe_loss.item())
    assert lambda_xe > 0
    loss = lambda_xe * xe_loss

    # check NaN
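    # (NaN is the only float for which x != x, hence the self-inequality test)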
    if (loss != loss).data.any():
        logger.error("NaN detected")
        exit()

    # optimizer
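    # (params.otf_update_enc / params.otf_update_dec select which modules are
    # updated by this back-translation step; at least one must be enabled)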
    assert params.otf_update_enc or params.otf_update_dec
    to_update = []
    if params.otf_update_enc:
        to_update.append('enc')
    if params.otf_update_dec:
        to_update.append('dec')
    self.zero_grad(to_update)
    loss.backward()
    self.update_params(to_update)

    # number of processed sentences / words
    self.stats['processed_s'] += len3.size(0)
    self.stats['processed_w'] += len3.sum()
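
# A minimal usage sketch, assuming a constructed trainer instance and a
# hypothetical get_bt_batch() helper that yields the sent1/sent2/sent3 triples
# (neither name is confirmed by this excerpt):
#
#   batch = get_bt_batch()                      # hypothetical helper
#   trainer.otf_bt(batch, lambda_xe=1.0, backprop_temperature=-1)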