def forward()

in custom/sequence_penalty_loss.py


    def forward(self, model, sample, reduce=True, generator=None):
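        """Sequence-level completion penalty.

        Greedily decode continuations of fixed-length prefixes drawn from
        src_tokens, then penalize candidate tokens (tokens in repeated
        n-grams, or a random subset) via -log(1 - p(token)).
        """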
        seq_len = sample['net_input']['src_tokens'].size(1)

        # Keep only enough prefixes that the total number of decoded tokens
        # (prefix + completion per prefix) roughly matches the original
        # sequence length (for memory purposes).
        n_batches = seq_len // (self.sequence_prefix_length + self.sequence_completion_length)
        batch = batch_input_sequence_by_prefix_length(sample['net_input']['src_tokens'],
                                                      prefix_length=self.sequence_prefix_length)
        batch = batch[:n_batches]

        pred_toks, lprobs = generator.generate_completion_greedy_training(model, batch,
                                                                          completion_length=self.sequence_completion_length)
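        # Assumed shapes: pred_toks is (n_batches, completion_length) greedy
        # continuations; lprobs is (n_batches, completion_length, vocab_size).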
        if self.sequence_candidate_type == 'repeat':
            # Penalize tokens that complete a repeated n-gram in the greedy completion.
            mask = ngram_repeat_mask(pred_toks, self.sequence_ngram_n).type_as(lprobs)
        elif self.sequence_candidate_type == 'random':
            # Penalize a random subset of completion tokens, each with probability mask_p.
            mask = torch.bernoulli(torch.zeros_like(pred_toks, dtype=torch.float).fill_(self.mask_p))
        else:
            raise NotImplementedError('unknown sequence_candidate_type: %s' % self.sequence_candidate_type)

        # Gather the log-probability of each greedily chosen token from the vocab dim.
        pred_lprobs = lprobs.view(-1, lprobs.size(2)).gather(1, pred_toks.view(-1, 1))
        # Unlikelihood term: the clamp keeps log() finite when a token's probability is near 1.
        one_minus_probs = torch.clamp((1.0 - pred_lprobs.exp()), min=1e-20).view(pred_toks.size(0), pred_toks.size(1))
        loss = -torch.log(one_minus_probs) * mask
        loss = loss.sum()

        ntokens = pred_toks.numel()  # number of output tokens (tokens in completions)
        nsentences = batch.size(0)
        sample_size = ntokens
        logging_output = {
            'seq_loss': utils.item(loss.data),
            'seq_ntokens': ntokens,
            'seq_nsentences': nsentences,
            'seq_repeat_mask': utils.item(mask.sum().data),
            'seq_sample_size': sample_size,
        }

        # Sum each statistic, which will be normalized by the number of sentences in `aggregate_logging_outputs`.
        stats = defaultdict(float)
        for tok_list in pred_toks.cpu().tolist():
            ms = ngram_metrics(tok_list)
            for k, v in ms.items():
                stats[k] += v
        logging_output.update(stats)

        return loss, sample_size, logging_output
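
For reference, here is a minimal sketch of the two helpers this method leans on. These are assumptions inferred from the call sites, not the repository's actual implementations: batch_input_sequence_by_prefix_length is taken to chop each source row into as many fixed-length prefixes as fit, and ngram_repeat_mask to flag tokens belonging to an n-gram that already occurred earlier in the same row.

    import torch

    def batch_input_sequence_by_prefix_length(src_tokens, prefix_length):
        # Hypothetical sketch: reshape each row into fixed-length prefixes,
        # dropping any remainder that does not fill a whole prefix.
        seq_len = src_tokens.size(1)
        usable_len = (seq_len // prefix_length) * prefix_length
        return src_tokens[:, :usable_len].reshape(-1, prefix_length)

    def ngram_repeat_mask(pred_toks, n):
        # Hypothetical sketch: mask[i, j:j+n] = 1 when the n-gram starting at
        # position j already occurred earlier in row i.
        mask = torch.zeros_like(pred_toks)
        for i, row in enumerate(pred_toks.tolist()):
            seen = set()
            for j in range(len(row) - n + 1):
                ngram = tuple(row[j:j + n])
                if ngram in seen:
                    mask[i, j:j + n] = 1
                seen.add(ngram)
        return mask

Under these assumptions, a (bsz, 300) src_tokens batch with sequence_prefix_length=50 yields six 50-token prefixes per row, of which forward() keeps only the first n_batches before decoding completions.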