def forward()

in dpr_scale/transforms/dpr_transform.py
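
For reference, a sketch of the two row formats this method accepts. The field names come from the code below; the concrete values are invented for illustration.

    import ujson

    # Native format: explicit positive and hard-negative context lists.
    row_native = {
        "question": "who wrote hamlet",
        "positive_ctxs": [
            {"text": "Hamlet is a tragedy written by William Shakespeare.",
             "title": "Hamlet"}
        ],
        "hard_negative_ctxs": [
            {"text": "Macbeth is a tragedy by William Shakespeare.",
             "title": "Macbeth"}
        ],
    }

    # DPR retriever output format: a ranked "ctxs" list with "has_answer"
    # flags, which forward() splits into positives and hard negatives.
    row_dpr = {
        "question": "who wrote hamlet",
        "ctxs": [
            {"text": "Hamlet is a tragedy written by William Shakespeare.",
             "title": "Hamlet", "has_answer": True},
            {"text": "Macbeth is a tragedy by William Shakespeare.",
             "title": "Macbeth", "has_answer": False},
        ],
    }

    # Rows arrive as JSON strings, hence the ujson.loads call in forward().
    batch = [ujson.dumps(row_native), ujson.dumps(row_dpr)]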


    def forward(self, batch, stage="train"):
        """
        Combines pos and neg contexts. Samples randomly limited number of pos/neg
        contexts if stage == "train". Also ensures we have exactly num_negative
        contexts by padding with fake contexts if the data did not have enough.
        A boolean mask is created to ignore these fake contexts when training.
        """
        questions = []
        all_ctxs = []
        positive_ctx_indices = []
        ctx_mask = []
        rows = batch if isinstance(batch, list) else batch[self.text_column]
        for row in rows:
            row = ujson.loads(row)
            # Also support the DPR retriever output format: split "ctxs" into
            # positives and hard negatives based on the "has_answer" flag.
            if "positive_ctxs" not in row and "ctxs" in row:
                row["positive_ctxs"] = []
                row["hard_negative_ctxs"] = []
                for ctx in row["ctxs"]:
                    if ctx["has_answer"]:
                        row["positive_ctxs"].append(ctx)
                    else:
                        row["hard_negative_ctxs"].append(ctx)
                # Fall back to the top-ranked context if none has the answer.
                if not row["positive_ctxs"]:
                    row["positive_ctxs"].append(row["ctxs"][0])

            # sample positive contexts
            contexts_pos = row["positive_ctxs"]

            # Handle the case where the context text is a list of tokens
            # instead of a string.
            if contexts_pos and not isinstance(contexts_pos[0]["text"], str):
                for c in contexts_pos:
                    c["text"] = " ".join(c["text"])

            if stage == "train" and self.pos_ctx_sample:
                contexts_pos = random.sample(
                    contexts_pos, min(len(contexts_pos), self.num_positive)
                )
            else:
                contexts_pos = contexts_pos[: self.num_positive]

            # sample negative contexts
            contexts_neg = row["hard_negative_ctxs"]
            if stage == "train":
                num_neg_sample = self.num_negative
            elif stage == "eval":
                num_neg_sample = self.num_val_negative
            elif stage == "test":
                num_neg_sample = self.num_test_negative
            else:
                raise ValueError(f"Unexpected stage: {stage}")

            if num_neg_sample > 0:
                if (
                    stage == "train"
                    and self.neg_ctx_sample
                    and len(contexts_neg) > num_neg_sample
                ):
                    contexts_neg = random.sample(contexts_neg, num_neg_sample)
                else:
                    contexts_neg = contexts_neg[:num_neg_sample]
            else:
                contexts_neg = []

            ctxs = contexts_pos + contexts_neg
            # Pad negatives up to num_neg_sample with dummy contexts; mask
            # entries are 1 (True) exactly at the dummy contexts.
            mask = [0] * len(ctxs)
            if len(contexts_neg) < num_neg_sample:
                num_dummy = num_neg_sample - len(contexts_neg)
                ctxs.extend([{"text": "0", "title": "0"}] * num_dummy)
                mask.extend([1] * num_dummy)

            current_ctxs_len = len(all_ctxs)
            all_ctxs.extend(ctxs)
            # Index of this row's first positive context in the flattened list.
            positive_ctx_indices.append(current_ctxs_len)
            questions.append(row["question"])
            ctx_mask.extend(mask)

        # Optionally prepend each passage's title before tokenization.
        ctx_text = [
            maybe_add_title(x["text"], x["title"], self.use_title, self.sep_token)
            for x in all_ctxs
        ]
        question_tensors = self._transform(questions)
        ctx_tensors = self._transform(ctx_text)
        return {
            "query_ids": question_tensors,
            "contexts_ids": ctx_tensors,
            "pos_ctx_indices": torch.tensor(positive_ctx_indices, dtype=torch.long),
            "ctx_mask": torch.tensor(ctx_mask, dtype=torch.bool),
        }
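
For orientation, a hedged usage sketch. It assumes an already-constructed instance of the enclosing transform class (built elsewhere with a tokenizer; the constructor is not part of this excerpt) and reuses the batch from the input-format sketch above; the output keys and their alignment are taken from the return statement.

    # `transform` is assumed to be an instance of the enclosing class.
    out = transform.forward(batch, stage="train")

    # Each row contributes up to num_positive positive contexts plus exactly
    # num_negative negatives, padded with dummy {"text": "0", "title": "0"}
    # entries when the row has too few hard negatives.
    out["query_ids"]        # tokenized questions, one per row
    out["contexts_ids"]     # tokenized contexts, flattened across the batch
    out["pos_ctx_indices"]  # per row: index of its first positive context
                            # within the flattened context list
    out["ctx_mask"]         # True exactly at the padded (fake) contexts

maybe_add_title is imported from elsewhere in dpr_scale; judging only from its call site above, it presumably reduces to something like this sketch (an assumption, not the repo's implementation):

    def maybe_add_title(text, title, use_title, sep_token):
        # Assumed behavior: prepend the passage title, separated by the
        # tokenizer's sep token, when use_title is enabled.
        return f"{title} {sep_token} {text}" if use_title else text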