def __call__()

in curiosity/metrics.py [0:0]


    def __call__(self, logits: torch.Tensor, labels: torch.Tensor, mask: torch.Tensor):
        """
        Accumulate the reciprocal rank of every relevant document in a batch.

        logits and labels should be the same shape; the last dimension
        indexes the candidate documents. Labels should be an array of
        0/1s to indicate if the document is relevant.

        ``mask`` is accepted but never read: only nonzero labels are
        selected, and masked entries in labels are expected never to be 1
        (an assumption carried over from the original author, not checked
        here).

        Each reciprocal rank is appended to ``self._reciprocal_ranks``.

        Returns the mean reciprocal rank of this call as a 0-dim tensor, or
        ``None`` (recording nothing) when the labels sum to zero.
        """
        n_relevant = labels.sum().item()
        if n_relevant == 0:
            # None are relevant, no-op
            return None

        # preds[..., k] is the index of the document ranked k-th (0 = top).
        preds = logits.argsort(dim=-1, descending=True)
        # Every row of preds is a permutation of the document indices, so
        # argsorting it again gives the inverse permutation:
        # ranks[..., d] is the zero-based rank of document d.
        ranks = preds.argsort(dim=-1)
        # Boolean-mask selection visits entries in the same row-major order
        # as labels.nonzero(), so results come out in the order the former
        # per-label loop produced them.
        all_ranks = ranks[labels != 0]

        # rank starts at zero from torch, += 1 before inverting it
        reciprocal_ranks = 1 / (1 + all_ranks).float()
        self._reciprocal_ranks.extend(reciprocal_ranks.tolist())
        return reciprocal_ranks.mean()