in dpr_scale/task/dpr_task.py [0:0]
def _eval_epoch_end(self, outputs, log_prefix="valid"):
total_avg_rank, total_ctx_count, total_count = 0, 0, 0
total_mrr = 0
total_loss = 0
total_score = 0
if self.in_batch_eval:
for metrics, query_repr, contexts_repr, _, mask, loss in outputs:
rank, mrr, score = metrics
total_avg_rank += rank
total_mrr += mrr
total_score += score
total_ctx_count += contexts_repr.size(0) - torch.sum(mask)
total_count += query_repr.size(0)
total_loss += loss
total_ctx_count = total_ctx_count / len(outputs)
else:
# collate the representation and gold +ve labels
all_query_repr = []
all_context_repr = []
all_labels = []
all_mask = []
offset = 0
for _, query_repr, context_repr, target_labels, mask, _ in outputs:
all_query_repr.append(query_repr)
all_context_repr.append(context_repr)
all_mask.append(mask)
all_labels.extend([offset + x for x in target_labels])
offset += context_repr.size(0)
# gather all contexts
all_context_repr = torch.cat(all_context_repr, dim=0)
all_mask = torch.cat(all_mask, dim=0)
if self.trainer.accelerator_connector.use_ddp:
all_context_repr, all_mask = self.all_gather(
(all_context_repr, all_mask)
)
all_labels = [
x + all_context_repr.size(1) * self.global_rank for x in all_labels
]
all_context_repr = torch.cat(tuple(all_context_repr), dim=0)
all_mask = torch.cat(tuple(all_mask), dim=0)
all_query_repr = torch.cat(all_query_repr, dim=0)
scores = self.sim_score(all_query_repr, all_context_repr, all_mask)
total_count = all_query_repr.size(0)
total_ctx_count = scores.size(1) - torch.sum(all_mask)
total_avg_rank, total_mrr, total_score = self.compute_rank_metrics(
scores, all_labels
)
total_loss = self.loss(
scores,
torch.tensor(all_labels).to(scores.device, dtype=torch.long),
)
metrics = {
log_prefix + "_avg_rank": total_avg_rank / total_count,
log_prefix + "_mrr": total_mrr / total_count,
log_prefix + f"_accuracy@{self.k}": total_score / total_count,
log_prefix + "_ctx_count": total_ctx_count,
log_prefix + "_loss": total_loss,
}
self.log_dict(metrics, on_epoch=True, sync_dist=True)