def otf_bt()

in code/src/trainer.py [0:0]


    def otf_bt(self, batch, lambda_xe, backprop_temperature):
        """
        On the fly back-translation.

        Runs one cross-entropy training step: `sent2` is encoded, the decoder
        is fed `sent3[:-1]` together with `attr3`, and its scores are compared
        against the shifted targets `sent3[1:]` (the "attr2 -> attr3"
        direction). The loss is scaled by `lambda_xe` and back-propagated into
        the encoder and/or decoder, depending on `params.otf_update_enc` and
        `params.otf_update_dec`.

        Args:
            batch: mapping from which 'sent2', 'len2', 'sent3', 'len3' and
                'attr3' are read. The sentence and attribute entries are
                moved to the GPU with `.cuda()`; the length entries are used
                as given. NOTE(review): presumably `sent2` is the
                model-generated side of the pair (the log message speaks of a
                "generated CPU batch") -- confirm against the caller.
            lambda_xe: weight applied to the cross-entropy loss. When it is
                0 the batch is dropped with a warning and nothing is trained;
                otherwise it must be positive.
            backprop_temperature: only -1 is supported. The temperature-based
                variant (decode sent1 -> sent2 as a softmax over the
                vocabulary and feed that to the encoder) has no attribute
                embedding support yet.

        Returns:
            None. Results are recorded in `self.stats` ('xe_bt',
            'processed_s', 'processed_w').

        Raises:
            NotImplementedError: if `backprop_temperature` is not -1.
            AssertionError: if `lambda_xe` is negative, or if neither
                `params.otf_update_enc` nor `params.otf_update_dec` is set.
            SystemExit: with status 1 if the loss is NaN.
        """
        if lambda_xe == 0:
            logger.warning("Unused generated CPU batch!")
            return

        params = self.params
        self.encoder.train()
        self.decoder.train()

        if backprop_temperature != -1:
            # NotImplementedError is an Exception subclass, so callers that
            # catch Exception keep working.
            raise NotImplementedError("Not implemented for attributes yet! Need to add attribute embedding below.")

        # Configuration sanity checks, done before any GPU work so that a
        # misconfigured run fails fast.
        assert lambda_xe > 0
        assert params.otf_update_enc or params.otf_update_dec

        # prepare batch -- only the tensors the attr2 -> attr3 step needs are
        # transferred to the GPU
        sent2, len2 = batch['sent2'].cuda(), batch['len2']
        sent3, len3 = batch['sent3'].cuda(), batch['len3']
        attr3 = batch['attr3'].cuda()
        n_words3 = params.n_words

        # attr2 -> attr3
        encoded = self.encoder(sent2, len2)

        # cross-entropy scores / loss (targets are the inputs shifted by one)
        scores = self.decoder(encoded, sent3[:-1], attr3)
        xe_loss = self.decoder.loss_fn(scores.view(-1, n_words3), sent3[1:].view(-1))
        self.stats['xe_bt'].append(xe_loss.item())
        loss = lambda_xe * xe_loss

        # check NaN
        if torch.isnan(loss).any():
            logger.error("NaN detected")
            # The bare `exit()` helper is injected by the `site` module for
            # interactive use and terminates with status 0; a NaN loss is a
            # failure, so exit through `sys` with a non-zero status.
            import sys
            sys.exit(1)

        # optimizer
        to_update = []
        if params.otf_update_enc:
            to_update.append('enc')
        if params.otf_update_dec:
            to_update.append('dec')
        self.zero_grad(to_update)
        loss.backward()
        self.update_params(to_update)

        # number of processed sentences / words -- stored as plain Python
        # numbers so the counters never turn into tensors
        self.stats['processed_s'] += len3.size(0)
        self.stats['processed_w'] += len3.sum().item()