in curiosity/metrics.py [0:0]
def __call__(self, logits: torch.Tensor, labels: torch.Tensor, mask: torch.Tensor):
    """
    Accumulate reciprocal ranks of the relevant documents and return their mean.

    ``logits`` and ``labels`` should be the same shape, with candidate
    documents along the last dimension. Labels should be an array of 0/1s
    indicating whether each document is relevant. Every nonzero label
    contributes one reciprocal rank ``1 / (1 + rank)``, where ``rank`` is the
    0-based position of that document when documents are sorted by
    descending logit.

    ``mask`` is accepted for interface compatibility but is not used:
    relevant entries are selected by nonzero label, and masked entries are
    expected to be 0 in ``labels``.
    NOTE(review): that assumption was never verified -- confirm it against
    the data pipeline.

    Returns ``None`` (and records nothing) when no label is nonzero.
    Otherwise the individual reciprocal ranks are appended to
    ``self._reciprocal_ranks`` and their mean for this call is returned as a
    scalar tensor.
    """
    # One index tensor per dimension, covering every relevant (nonzero) label,
    # in row-major order.
    relevant = labels.nonzero(as_tuple=True)
    # Guard on the same selection used below, so the two can never disagree
    # and we never index with an empty selection.
    if relevant[0].numel() == 0:
        # None are relevant, no-op
        return None
    # preds[..., k] is the index of the k-th highest-scoring document.
    preds = logits.argsort(dim=-1, descending=True)
    # preds is a permutation along the last dim, so argsort-ing it yields the
    # inverse permutation: ranks[..., d] is the 0-based rank of document d.
    # This batches what used to be a per-label Python loop (one device sync
    # per label) and works for any number of leading dimensions.
    ranks = preds.argsort(dim=-1)
    all_ranks = ranks[relevant]
    # Ranks are 0-based, so add 1 before inverting.
    reciprocal_ranks = 1 / (1 + all_ranks).float()
    self._reciprocal_ranks.extend(reciprocal_ranks.cpu().numpy().tolist())
    return reciprocal_ranks.mean()