in dpr_scale/transforms/dpr_transform.py

import random

import torch
import ujson

# maybe_add_title is a dpr_scale helper; this import path is assumed for the excerpt.
from dpr_scale.utils.utils import maybe_add_title


# Method of the transform class; shown here without its enclosing class.
def forward(self, batch, stage="train"):
"""
Combines pos and neg contexts. Samples randomly limited number of pos/neg
contexts if stage == "train". Also ensures we have exactly num_negative
contexts by padding with fake contexts if the data did not have enough.
A boolean mask is created to ignore these fake contexts when training.
"""
    questions = []
    all_ctxs = []  # flattened list of contexts across the whole batch
    positive_ctx_indices = []  # index of each question's positive ctx in all_ctxs
    ctx_mask = []  # 1 marks a padded (fake) context, 0 a real one
    # batch is either a list of JSON rows or a mapping holding them under text_column.
    rows = batch if isinstance(batch, list) else batch[self.text_column]
for row in rows:
row = ujson.loads(row)
        # Also support the DPR retrieval output format (ranked ctxs with has_answer flags).
if "positive_ctxs" not in row and "ctxs" in row:
row["positive_ctxs"] = []
row["hard_negative_ctxs"] = []
for ctx in row["ctxs"]:
if ctx["has_answer"]:
row["positive_ctxs"].append(ctx)
else:
row["hard_negative_ctxs"].append(ctx)
if not row["positive_ctxs"]:
row["positive_ctxs"].append(row["ctxs"][0])
# sample positive contexts
contexts_pos = row["positive_ctxs"]
        # Handle the case when context is a list of tokens instead of a string.
        if contexts_pos and not isinstance(contexts_pos[0]["text"], str):
            for c in contexts_pos:
                c["text"] = " ".join(c["text"])
if stage == "train" and self.pos_ctx_sample:
contexts_pos = random.sample(
contexts_pos, min(len(contexts_pos), self.num_positive)
)
else:
contexts_pos = contexts_pos[: self.num_positive]
# sample negative contexts
contexts_neg = row["hard_negative_ctxs"]
        if stage == "train":
            num_neg_sample = self.num_negative
        elif stage == "eval":
            num_neg_sample = self.num_val_negative
        elif stage == "test":
            num_neg_sample = self.num_test_negative
        else:
            # Fail fast instead of hitting an UnboundLocalError below.
            raise ValueError(f"Unknown stage: {stage}")
if num_neg_sample > 0:
if (
stage == "train"
and self.neg_ctx_sample
and len(contexts_neg) > num_neg_sample
):
contexts_neg = random.sample(contexts_neg, num_neg_sample)
else:
contexts_neg = contexts_neg[:num_neg_sample]
else:
contexts_neg = []
        ctxs = contexts_pos + contexts_neg
        # Pad the negatives up to num_neg_sample with dummy contexts so every
        # question contributes the same number of contexts to the batch.
        mask = [0] * len(ctxs)
        num_pad = num_neg_sample - len(contexts_neg)
        if num_pad > 0:
            ctxs.extend([{"text": "0", "title": "0"}] * num_pad)
            mask.extend([1] * num_pad)
current_ctxs_len = len(all_ctxs)
all_ctxs.extend(ctxs)
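        # Positives come first in each row, so this records the index of the
        # row's first positive context within the flattened all_ctxs list.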
positive_ctx_indices.append(current_ctxs_len)
questions.append(row["question"])
ctx_mask.extend(mask)
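    # Optionally prepend each context's title before tokenization.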
ctx_text = [
maybe_add_title(x["text"], x["title"], self.use_title, self.sep_token)
for x in all_ctxs
]
question_tensors = self._transform(questions)
ctx_tensors = self._transform(ctx_text)
return {
"query_ids": question_tensors,
"contexts_ids": ctx_tensors,
"pos_ctx_indices": torch.tensor(positive_ctx_indices, dtype=torch.long),
"ctx_mask": torch.tensor(ctx_mask, dtype=torch.bool),
}
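
A minimal usage sketch of the transform above. The row fields ("question",
"positive_ctxs", "hard_negative_ctxs", "text", "title") are exactly what
forward reads; the transform instance name and its config values are assumed
for illustration:

    # Hypothetical example: one question with one positive and one hard
    # negative context, serialized the way the data loader feeds rows in.
    row = {
        "question": "who wrote hamlet",
        "positive_ctxs": [{"text": "Hamlet is a tragedy by William Shakespeare.", "title": "Hamlet"}],
        "hard_negative_ctxs": [{"text": "Macbeth is a tragedy.", "title": "Macbeth"}],
    }
    out = transform.forward([ujson.dumps(row)], stage="train")
    # out["contexts_ids"] covers 1 positive + num_negative contexts; if fewer
    # real negatives exist, the rest are dummies and out["ctx_mask"] is True
    # at those positions. out["pos_ctx_indices"][0] == 0 because each row's
    # positive context is appended first.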