def _eval_epoch_end()

in dpr_scale/task/dpr_task.py [0:0]


    def _eval_epoch_end(self, outputs, log_prefix="valid"):
        """Aggregate per-step evaluation outputs and log epoch-level metrics.

        Two modes are supported, selected by ``self.in_batch_eval``:

        * in-batch: the metrics already computed by every eval step are
          accumulated and normalised.
        * global: query/context representations of all steps (and, under DDP,
          the contexts of all ranks) are concatenated and scored against each
          other before the metrics and the loss are computed once.

        Args:
            outputs: Sequence of per-step 6-tuples
                ``(metrics, query_repr, contexts_repr, target_labels, mask,
                loss)`` where ``metrics`` unpacks as ``(rank, mrr, score)``.
            log_prefix: Prefix prepended to every logged metric name.

        Logs (via ``self.log_dict``) ``<prefix>_avg_rank``, ``<prefix>_mrr``,
        ``<prefix>_accuracy@{self.k}``, ``<prefix>_ctx_count`` and
        ``<prefix>_loss``.
        """
        total_avg_rank, total_ctx_count, total_count = 0, 0, 0
        total_mrr = 0
        total_loss = 0
        total_score = 0
        if self.in_batch_eval:
            for metrics, query_repr, contexts_repr, _, mask, loss in outputs:
                rank, mrr, score = metrics
                # rank/mrr/score are divided by the query count below, so they
                # are presumably per-batch sums -- TODO confirm with the step.
                total_avg_rank += rank
                total_mrr += mrr
                total_score += score
                # Contexts flagged in `mask` are excluded from the count.
                total_ctx_count += contexts_repr.size(0) - torch.sum(mask)
                total_count += query_repr.size(0)
                total_loss += loss
            num_steps = len(outputs)
            total_ctx_count = total_ctx_count / num_steps
            # Average the per-step losses as well; otherwise the logged loss
            # grows with the number of eval batches and is not comparable with
            # the globally computed loss of the other branch.
            total_loss = total_loss / num_steps
        else:
            # collate the representation and gold +ve labels
            all_query_repr = []
            all_context_repr = []
            all_labels = []
            all_mask = []
            offset = 0
            for _, query_repr, context_repr, target_labels, mask, _ in outputs:
                all_query_repr.append(query_repr)
                all_context_repr.append(context_repr)
                all_mask.append(mask)
                # Shift the step-local label indices so they point into the
                # concatenated context matrix.
                all_labels.extend([offset + x for x in target_labels])
                offset += context_repr.size(0)
            # gather all contexts
            all_context_repr = torch.cat(all_context_repr, dim=0)
            all_mask = torch.cat(all_mask, dim=0)
            if self.trainer.accelerator_connector.use_ddp:
                all_context_repr, all_mask = self.all_gather(
                    (all_context_repr, all_mask)
                )
                # The gathered tensor carries one leading entry per rank, so
                # size(1) is the per-rank context count; this rank's labels
                # are shifted past the contexts of all lower ranks.
                # NOTE(review): assumes every rank holds the same number of
                # contexts -- confirm.
                all_labels = [
                    x + all_context_repr.size(1) * self.global_rank for x in all_labels
                ]
                all_context_repr = torch.cat(tuple(all_context_repr), dim=0)
                all_mask = torch.cat(tuple(all_mask), dim=0)
            all_query_repr = torch.cat(all_query_repr, dim=0)
            scores = self.sim_score(all_query_repr, all_context_repr, all_mask)
            total_count = all_query_repr.size(0)
            total_ctx_count = scores.size(1) - torch.sum(all_mask)
            total_avg_rank, total_mrr, total_score = self.compute_rank_metrics(
                scores, all_labels
            )
            total_loss = self.loss(
                scores,
                torch.tensor(all_labels).to(scores.device, dtype=torch.long),
            )
        metrics = {
            log_prefix + "_avg_rank": total_avg_rank / total_count,
            log_prefix + "_mrr": total_mrr / total_count,
            log_prefix + f"_accuracy@{self.k}": total_score / total_count,
            log_prefix + "_ctx_count": total_ctx_count,
            log_prefix + "_loss": total_loss,
        }
        self.log_dict(metrics, on_epoch=True, sync_dist=True)