in custom/metrics.py [0:0]
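# Imports assumed by this excerpt (not shown in the original snippet); TrainingMetrics
# and top_k_logits are expected to be defined elsewhere in this module or repo.
import torch
import torch.nn.functional as F
from fairseq import utils
from fairseq.sequence_generator import SequenceGenerator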
def ranking_metrics(logits, true_token_logits, sample, ntokens, targets, topk=1, topp=0.0):
    """Compute summed ranking, repetition, top-k, and top-p metrics on a batch."""
    negative_targets = (logits > true_token_logits[:, None]).float()
    negative_targets_count = negative_targets.sum(dim=1)
    target_rank = negative_targets_count.sum()
    median_target_rank = negative_targets_count.median()
    hits_at_1 = (negative_targets_count == 0).sum()
    hits_at_10 = (negative_targets_count < 10).sum()
    logging_output = {
        'target_rank': utils.item(target_rank.data),
        'hits_at_1': utils.item(hits_at_1.data),
        'hits_at_10': utils.item(hits_at_10.data),
        'median_target_rank': utils.item(median_target_rank),  # NOTE: different normalization since it's not a sum
        'normalizer': ntokens,
    }
    for length in TrainingMetrics.REPEAT_CONTEXT_LENGTHS:
        total_repeat_at_1, total_wrong_repeat_at_1, total_human_repeat_at_1 = \
            TrainingMetrics.repeat_at_1(logits, targets, context_length=length)
        logging_output.update({
            'repeat_at_1/%d' % length: utils.item(total_repeat_at_1.data),
            'wrong_repeat_at_1/%d' % length: utils.item(total_wrong_repeat_at_1.data),
            'human_repeat_at_1/%d' % length: utils.item(total_human_repeat_at_1.data),
        })
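    # Top-k metrics: truncate the distribution to the k highest-scoring tokens and
    # measure how much of the resulting probability mass falls on the true token
    # and on tokens already seen in the recent context.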
    if topk > 1:
        filtered_topk = top_k_logits(logits, topk)
        softmax_topk = F.softmax(filtered_topk, dim=1)
        true_target_topk_probs = torch.gather(softmax_topk, index=targets[:, None], dim=1).sum()
        logging_output['true_topk_{}_prob'.format(topk)] = true_target_topk_probs.item()
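        # Accumulate, per position, the top-k probability mass assigned to tokens
        # seen in the previous (up to 128) target tokens; the "wrepeat" variant
        # zeroes the true token first, so it counts only wrong repeats.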
        sum_topk_repeated_probs = 0
        sum_topk_wrepeated_probs = 0
        true_token_zeroed_topk_probs = softmax_topk.clone().scatter_(1, targets[:, None], 0)
        for timestep in range(1, targets.size(0)):
            prev_context = targets[max(0, timestep - 128):timestep]
            sum_topk_repeated_probs += torch.gather(softmax_topk[timestep], index=prev_context.unique(), dim=0).sum().item()
            sum_topk_wrepeated_probs += torch.gather(true_token_zeroed_topk_probs[timestep], index=prev_context.unique(), dim=0).sum().item()
        logging_output['repeat_topk_{}'.format(topk)] = sum_topk_repeated_probs
        logging_output['wrepeat_topk_{}'.format(topk)] = sum_topk_wrepeated_probs
        logging_output['nextunique_topk_{}'.format(topk)] = softmax_topk.multinomial(1).view(-1).tolist()
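    # Top-p (nucleus) metrics: _sample_topp yields the truncated distribution as a
    # (probabilities, token indices) pair, and the same true-token and repetition
    # statistics are computed over the nucleus instead of the top-k set.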
    if topp > 0.0:
        trimmed_topp = SequenceGenerator._sample_topp(SequenceGenerator, F.softmax(logits, dim=1), topp)
        target_mask = (trimmed_topp[1] - targets[:, None].expand(-1, trimmed_topp[1].size(1))) == 0
        true_target_topp_probs = torch.masked_select(trimmed_topp[0], target_mask).sum()
        logging_output['true_topp_{}_prob'.format(topp)] = true_target_topp_probs.item()
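        # As above, but the kept token set differs per position, so repeats are
        # found by matching the nucleus token indices against the recent context.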
        sum_topp_repeated_probs = 0
        sum_topp_wrepeated_probs = 0
        true_token_zeroed_topp_probs = torch.masked_fill(trimmed_topp[0], target_mask, 0)
        for timestep in range(1, targets.size(0)):
            prev_context = targets[max(0, timestep - 128):timestep]
            topp_mask = (trimmed_topp[1][timestep][:, None] == prev_context[None, :]).sum(1).nonzero()
            sum_topp_repeated_probs += torch.gather(trimmed_topp[0][timestep], index=topp_mask.view(-1), dim=0).sum().item()
            sum_topp_wrepeated_probs += torch.gather(true_token_zeroed_topp_probs[timestep], index=topp_mask.view(-1), dim=0).sum().item()
        logging_output['repeat_topp_{}'.format(topp)] = sum_topp_repeated_probs
        logging_output['wrepeat_topp_{}'.format(topp)] = sum_topp_wrepeated_probs
        logging_output['nextunique_topp_{}'.format(topp)] = torch.gather(trimmed_topp[1], index=trimmed_topp[0].multinomial(1), dim=1).view(-1).tolist()
    return logging_output
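# A minimal usage sketch, under assumptions not shown in this excerpt: `logits` is a
# (num_tokens, vocab_size) tensor of per-position logits, `targets` holds the gold
# token ids with shape (num_tokens,), and `sample` is the batch dict (unused here).
#
#     true_token_logits = logits.gather(1, targets[:, None]).squeeze(1)
#     stats = ranking_metrics(logits, true_token_logits, sample,
#                             ntokens=targets.numel(), targets=targets,
#                             topk=10, topp=0.9)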