in custom/sequence_penalty_loss.py
def forward(self, model, sample, reduce=True, generator=None):
    seq_len = sample['net_input']['src_tokens'].size(1)
    # Cap the number of prefixes so the total number of processed tokens stays
    # roughly equal to the original sequence length (for memory purposes).
    n_batches = seq_len // (self.sequence_prefix_length + self.sequence_completion_length)
    batch = batch_input_sequence_by_prefix_length(sample['net_input']['src_tokens'],
                                                  prefix_length=self.sequence_prefix_length)
    batch = batch[:n_batches]

    # Greedily decode a fixed-length completion for each prefix, keeping the
    # per-step log-probabilities so the penalty below can backpropagate.
    pred_toks, lprobs = generator.generate_completion_greedy_training(
        model, batch, completion_length=self.sequence_completion_length)

    # Negative candidates: tokens that take part in a repeated n-gram, or a
    # random subset of tokens selected with probability `mask_p`.
    if self.sequence_candidate_type == 'repeat':
        mask = ngram_repeat_mask(pred_toks, self.sequence_ngram_n).type_as(lprobs)
    elif self.sequence_candidate_type == 'random':
        mask = torch.bernoulli(torch.zeros_like(pred_toks, dtype=torch.float).fill_(self.mask_p))
    else:
        raise NotImplementedError('unknown sequence candidate type: %s' % self.sequence_candidate_type)

    # Sequence-level unlikelihood penalty: -log(1 - p(token)) for every masked token.
    pred_lprobs = lprobs.view(-1, lprobs.size(2)).gather(1, pred_toks.view(-1, 1))
    one_minus_probs = torch.clamp((1.0 - pred_lprobs.exp()), min=1e-20).view(pred_toks.size(0), pred_toks.size(1))
    loss = -torch.log(one_minus_probs) * mask
    loss = loss.sum()

    ntokens = pred_toks.numel()  # number of output tokens (tokens in completions)
    nsentences = batch.size(0)
    sample_size = ntokens
    logging_output = {
        'seq_loss': utils.item(loss.data),
        'seq_ntokens': ntokens,
        'seq_nsentences': nsentences,
        'seq_repeat_mask': utils.item(mask.sum().data),
        'seq_sample_size': sample_size,
    }

    # Sum each n-gram statistic; the sums are normalized by the number of
    # sentences in `aggregate_logging_outputs`.
    stats = defaultdict(float)
    for tok_list in pred_toks.cpu().tolist():
        ms = ngram_metrics(tok_list)
        for k, v in ms.items():
            stats[k] += v
    for k, v in stats.items():
        logging_output[k] = v

    return loss, sample_size, logging_output
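
For intuition, the penalty above is the sequence-level unlikelihood term: every generated token selected by the mask contributes -log(1 - p(token)), pushing probability mass away from tokens that, for example, extend a repeated n-gram. The following is a minimal, self-contained sketch of that computation on toy tensors; the shapes, values, and names (`mask_sketch`, `penalty`) are illustrative and not taken from the repository.

import torch

# Toy shapes: 2 prefixes, completions of 5 tokens, vocabulary of 10.
lprobs = torch.log_softmax(torch.randn(2, 5, 10), dim=-1)  # log p(token) per step
pred_toks = lprobs.argmax(dim=-1)                          # stand-in for greedy continuations
mask_sketch = torch.zeros_like(pred_toks, dtype=torch.float)
mask_sketch[:, 2:] = 1.0                                   # pretend the last three tokens repeat an n-gram

# Gather log p of each predicted token, convert to a probability, penalize 1 - p.
pred_lprobs = lprobs.gather(-1, pred_toks.unsqueeze(-1)).squeeze(-1)
one_minus_p = torch.clamp(1.0 - pred_lprobs.exp(), min=1e-20)
penalty = -(torch.log(one_minus_p) * mask_sketch).sum()

The clamp mirrors the min=1e-20 guard in the criterion: it keeps log(1 - p) finite when the model is nearly certain about a penalized token. The 'repeat' branch relies on ngram_repeat_mask, defined elsewhere in the module; assuming it flags every token belonging to an n-gram that already occurred earlier in the same completion, a plausible sketch (the repository's exact implementation may differ) is:

def ngram_repeat_mask_sketch(tokens, n):
    # tokens: (batch, length) LongTensor of generated token ids.
    mask = torch.zeros_like(tokens, dtype=torch.float)
    for i, row in enumerate(tokens.tolist()):
        seen = set()
        for j in range(len(row) - n + 1):
            ngram = tuple(row[j:j + n])
            if ngram in seen:
                mask[i, j:j + n] = 1.0  # penalize all tokens of the repeated n-gram
            seen.add(ngram)
    return mask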