def training_step()

in dpr_scale/task/dpr_task.py


    def training_step(self, batch, batch_idx):
        """
        This receives queries, each with multiple contexts.
        """
        query_ids = batch["query_ids"]  # bs x tokens
        contexts_ids = batch["contexts_ids"]  # ctx_cnt x ctx_len
        pos_ctx_indices = batch["pos_ctx_indices"]  # bs
        mask = batch["ctx_mask"]  # ctx_cnt
        query_repr, context_repr = self(query_ids, contexts_ids)  # bs x dim, ctx_cnt x dim

        if self.trainer.accelerator_connector.use_ddp:
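            # detach the copies sent to other ranks so gradients flow only
            # through this rank's own query/context representations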
            query_to_send = query_repr.detach()
            context_to_send = context_repr.detach()
            # assumes all nodes have same number of contexts
            all_query_repr, all_context_repr, all_labels, all_mask = self.all_gather(
                (query_to_send, context_to_send, pos_ctx_indices, mask)
            )
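            # Lightning's all_gather returns tensors with a leading
            # world_size dimension; index i below selects rank i's shard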
            offset = 0
            all_query_list = []
            all_context_list = []

            for i in range(all_labels.size(0)):
                if i != self.global_rank:
                    all_query_list.append(all_query_repr[i])
                    all_context_list.append(all_context_repr[i])
                else:
                    # to calculate grads for this node only
                    all_query_list.append(query_repr)
                    all_context_list.append(context_repr)
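                # shift rank-local positive indices into the concatenated
                # context index space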
                all_labels[i] += offset
                offset += all_context_repr[i].size(0)

            context_repr = torch.cat(all_context_list, dim=0)  # total_ctx x dim
            query_repr = torch.cat(all_query_list, dim=0)  # total_query x dim
            pos_ctx_indices = torch.flatten(all_labels)  # total_query
            mask = torch.flatten(all_mask)  # total_ctx

        scores = self.sim_score(query_repr, context_repr, mask)  # total_query x total_ctx
        loss = self.loss(scores, pos_ctx_indices)
        self.log("train_loss", loss, prog_bar=True)
        return loss
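
The bodies of self.sim_score and self.loss are not shown in this excerpt. Below is a minimal sketch of what they plausibly compute, assuming dot-product similarity, a ctx_mask that flags contexts to be excluded from scoring, and softmax cross-entropy over each query's row of scores; the function bodies are illustrative assumptions, not the repository's actual code.

    import torch
    import torch.nn.functional as F

    def sim_score(query_repr, context_repr, mask):
        # query_repr: total_query x dim, context_repr: total_ctx x dim
        scores = torch.matmul(query_repr, context_repr.t())  # total_query x total_ctx
        # assumed semantics: mask flags padded/invalid contexts; set their
        # scores to -inf so softmax assigns them zero probability
        return scores.masked_fill(mask.bool().unsqueeze(0), float("-inf"))

    def loss(scores, pos_ctx_indices):
        # softmax cross-entropy per query; the positive context's position in
        # the concatenated context list is the target class
        return F.cross_entropy(scores, pos_ctx_indices)

Under this formulation every gathered context acts as a negative for every query except its own positive, which is why the DDP branch concatenates contexts from all ranks before scoring.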