In custom/gpt2/run_gpt2.py:
# `torch`, `torch.nn.functional as F`, and `collections.defaultdict` are imported
# elsewhere in this file, along with the sampling and n-gram helpers used below.
def ul_seq(model, batch, args):
    """Sequence-level unlikelihood loss: sample continuations from prefixes, then
    penalize continuation tokens that fall inside repeated n-grams."""
    input_sequence = batch[0].cuda()
    batch = batch_input_sequence_by_prefix_length(input_sequence, args.prefix_length)
    completions, continuation_logits = sample_sequence(model, batch,
        args.prefix_length, args.continuation_length, args.top_k, args.top_p)
    pred_toks = completions[:, args.prefix_length:].contiguous()

    # Repeated n-grams in the continuation are marked with 1 in the mask;
    # only those tokens receive the unlikelihood penalty.
    mask = ngram_repeat_mask(pred_toks, args.sequence_ngram_n).type_as(continuation_logits)

    # -log(1 - p(token)) for each sampled continuation token; the clamp avoids log(0).
    lprobs = F.log_softmax(continuation_logits, dim=-1)
    pred_lprobs = lprobs.view(-1, lprobs.size(2)).gather(1, pred_toks.view(-1, 1))
    one_minus_probs = torch.clamp((1.0 - pred_lprobs.exp()), min=1e-20).view(
        pred_toks.size(0), pred_toks.size(1))
    loss = -torch.log(one_minus_probs) * mask
    loss = loss.sum()

    ntokens = pred_toks.numel()  # number of output tokens (tokens in completions)
    logging_output = {
        'seq_loss': loss.item(),
        'seq_sample_size': ntokens,
        'seq_ntokens': ntokens,
        'seq_nsentences': batch.size(0),
        'seq_repeat_mask': mask.sum().item(),
    }

    # Sum each statistic, which will be normalized by the number of sentences
    # in `aggregate_logging_outputs`.
    stats = defaultdict(float)
    for tok_list in pred_toks.cpu().tolist():
        ms = ngram_metrics(tok_list)
        for k, v in ms.items():
            stats[k] += v
    for k, v in stats.items():
        logging_output[k] = v

    loss = loss / ntokens
    return loss, logging_output
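
# For reference: `ngram_repeat_mask` is defined elsewhere in the repository and is the
# piece that restricts the -log(1 - p(token)) penalty above to repeated n-grams. The
# sketch below illustrates the idea (it is not the repository's exact implementation):
# flag every token belonging to an n-gram that already occurred earlier in the same row.
def ngram_repeat_mask_sketch(pred_toks, n):
    # pred_toks: LongTensor of sampled token ids, shape (batch, continuation_length)
    mask = torch.zeros_like(pred_toks)
    for i, row in enumerate(pred_toks.tolist()):
        seen = set()
        for j in range(len(row) - n + 1):
            ngram = tuple(row[j:j + n])
            if ngram in seen:
                mask[i, j:j + n] = 1  # penalize the whole repeated n-gram
            seen.add(ngram)
    return mask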