def forward()

in dpr_scale/transforms/dpr_transform.py [0:0]


    def forward(self, batch, stage="train"):
        """
        Combines pos and neg contexts. Samples randomly limited number of pos/neg
        contexts if stage == "train". Then concatenates them for cross attention
        training along with the labels.
        """
        all_ctxs = []
        all_labels = []
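        # batch may be a list of serialized rows or a dict keyed by self.text_column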
        rows = batch if isinstance(batch, list) else batch[self.text_column]
        neg_candidates = []
        for row in rows:
            # collect candidates for random in-batch negatives
            row = ujson.loads(row)
            neg_candidates.extend(row["positive_ctxs"])
            neg_candidates.extend(row["hard_negative_ctxs"])
        for row in rows:
            row = ujson.loads(row)
            # also support DPR output format
            if "positive_ctxs" not in row and "ctxs" in row:
                row["positive_ctxs"] = []
                row["hard_negative_ctxs"] = []
                for ctx in row["ctxs"]:
                    if ctx["has_answer"]:
                        row["positive_ctxs"].append(ctx)
                    else:
                        row["hard_negative_ctxs"].append(ctx)
                if not row["positive_ctxs"]:
                    row["positive_ctxs"].append(row["ctxs"][0])

            # sample positive contexts
            contexts_pos = row["positive_ctxs"]

            # Handle the case where the context text is a list of tokens instead of a string.
            if not isinstance(contexts_pos[0]["text"], str):
                for c in contexts_pos:
                    c["text"] = " ".join(c["text"])

            if stage == "train" and self.pos_ctx_sample:
                contexts_pos = random.sample(
                    contexts_pos, min(len(contexts_pos), self.num_positive)
                )
            else:
                contexts_pos = contexts_pos[: self.num_positive]

            # sample negative contexts
            contexts_neg = row["hard_negative_ctxs"]
            num_random_negs = 0
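            # random in-batch negatives (num_random_negs) are only used during training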
            if stage == "train":
                num_neg_sample = self.num_negative
                num_random_negs = self.num_random_negs
            elif stage == "eval":
                num_neg_sample = self.num_val_negative
            elif stage == "test":
                num_neg_sample = self.num_test_negative

            if num_neg_sample > 0:
                if (
                    stage == "train"
                    and self.neg_ctx_sample
                    and len(contexts_neg) > num_neg_sample
                ):
                    contexts_neg = random.sample(contexts_neg, num_neg_sample)
                else:
                    contexts_neg = contexts_neg[:num_neg_sample]
            else:
                contexts_neg = []
            # Concatenate the question and each context text with the sep token
            ctxs = contexts_pos + contexts_neg
            if len(contexts_neg) < num_neg_sample + num_random_negs:
                # pad with random in-batch negatives up to num_neg_sample + num_random_negs
                ctxs.extend(
                    random.sample(neg_candidates, (num_neg_sample + num_random_negs - len(contexts_neg)))
                )
            concat_ctxs = [
                maybe_add_title(ctx["text"], row["question"], True, self.sep_token)
                for ctx in ctxs
            ]
            all_ctxs.extend(concat_ctxs)
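            # positives are placed first, so the label marking the gold index is always "0"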
            all_labels.append("0")

        return self.text_transform(
            {
                "text": all_ctxs,
                "label": all_labels,
            }
        )
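
For reference, below is a minimal sketch of the input this transform consumes, assuming the batch is a list of JSON-serialized rows (the keys shown are the ones forward() reads, besides the DPR output format keys "ctxs"/"has_answer"; the example values and the transform call are illustrative, not taken from dpr_scale):

    import ujson

    # One serialized row per example.
    row = {
        "question": "Who wrote the Iliad?",
        "positive_ctxs": [{"text": "The Iliad is an ancient Greek epic attributed to Homer."}],
        "hard_negative_ctxs": [{"text": "The Aeneid is a Latin epic written by Virgil."}],
    }
    batch = [ujson.dumps(row)]

    # Hypothetical call; transform would be an instance of the class this method belongs to.
    # forward() pairs each context text with the question via maybe_add_title and the sep
    # token, pads the context list with random in-batch negatives when needed, and passes
    # the flattened {"text": [...], "label": ["0", ...]} dict to self.text_transform.
    # model_inputs = transform.forward(batch, stage="train")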