in dpr_scale/transforms/dpr_transform.py [0:0]
def forward(self, batch, stage="train"):
    """
    Combines pos and neg contexts. Samples randomly limited number of pos/neg
    contexts if stage == "train". Then concatenates them for cross attention
    training along with the labels.

    Each row is a JSON string that holds a "question" plus either
    "positive_ctxs" / "hard_negative_ctxs" lists or, in the DPR output
    format, a single "ctxs" list whose entries carry a "has_answer" flag.
    Entries with a truthy flag become positives, the rest hard negatives;
    if none is flagged, the first entry of "ctxs" is used as the positive.

    Rows that end up with fewer than the requested number of negatives are
    padded with contexts drawn at random from all positive and hard negative
    contexts of the batch.

    Args:
        batch: either a list of JSON strings, or an object indexed with
            ``self.text_column`` to obtain such a list.
        stage: one of "train", "eval" or "test". Selects which negative
            count is used (``self.num_negative``, ``self.num_val_negative``
            or ``self.num_test_negative``). Random sampling of contexts and
            the extra ``self.num_random_negs`` negatives apply to "train"
            only.

    Returns:
        The result of ``self.text_transform`` applied to a dict with the
        formatted contexts under "text" and one "0" label per row under
        "label".

    Raises:
        ValueError: if ``stage`` is not one of the supported values, or
            (raised by ``random.sample``) if the batch does not contain
            enough contexts to pad a row with random negatives.
        KeyError: if a row has neither "positive_ctxs" nor "ctxs".
        IndexError: if a row has no positive context to fall back on.
    """
    rows = batch if isinstance(batch, list) else batch[self.text_column]

    # Stage-dependent settings do not change from row to row, so resolve
    # them once and fail loudly on an unknown stage instead of hitting an
    # unbound local further down.
    num_random_negs = 0
    if stage == "train":
        num_neg_sample = self.num_negative
        num_random_negs = self.num_random_negs
    elif stage == "eval":
        num_neg_sample = self.num_val_negative
    elif stage == "test":
        num_neg_sample = self.num_test_negative
    else:
        raise ValueError(
            f"Unsupported stage {stage!r}; expected 'train', 'eval' or 'test'."
        )
    is_train = stage == "train"

    # First pass: parse every row exactly once, normalise its layout and
    # collect candidates for random in-batch negatives. The normalisation
    # has to happen before the candidates are collected, otherwise rows in
    # DPR output format (no "positive_ctxs" key yet) raise a KeyError.
    parsed_rows = []
    neg_candidates = []
    for raw_row in rows:
        row = ujson.loads(raw_row)
        # also support DPR output format
        if "positive_ctxs" not in row and "ctxs" in row:
            row["positive_ctxs"] = []
            row["hard_negative_ctxs"] = []
            for ctx in row["ctxs"]:
                if ctx["has_answer"]:
                    row["positive_ctxs"].append(ctx)
                else:
                    row["hard_negative_ctxs"].append(ctx)
            if not row["positive_ctxs"]:
                row["positive_ctxs"].append(row["ctxs"][0])

        # Handle case when context is a list of tokens instead of string.
        # An explicit isinstance check is used rather than assert, because
        # asserts are stripped when Python runs with -O.
        # NOTE(review): as before, only positive contexts are joined; hard
        # negatives are passed through unchanged - confirm they are always
        # plain strings.
        contexts_pos = row["positive_ctxs"]
        if not isinstance(contexts_pos[0]["text"], str):
            for ctx in contexts_pos:
                ctx["text"] = " ".join(ctx["text"])

        neg_candidates.extend(contexts_pos)
        neg_candidates.extend(row["hard_negative_ctxs"])
        parsed_rows.append(row)

    # Second pass: pick the contexts of every row and format them.
    all_ctxs = []
    all_labels = []
    for row in parsed_rows:
        # sample positive contexts
        contexts_pos = row["positive_ctxs"]
        if is_train and self.pos_ctx_sample:
            contexts_pos = random.sample(
                contexts_pos, min(len(contexts_pos), self.num_positive)
            )
        else:
            contexts_pos = contexts_pos[: self.num_positive]

        # sample negative contexts
        contexts_neg = row["hard_negative_ctxs"]
        if num_neg_sample > 0:
            if (
                is_train
                and self.neg_ctx_sample
                and len(contexts_neg) > num_neg_sample
            ):
                contexts_neg = random.sample(contexts_neg, num_neg_sample)
            else:
                contexts_neg = contexts_neg[:num_neg_sample]
        else:
            contexts_neg = []

        ctxs = contexts_pos + contexts_neg
        # Pad with random contexts of the batch when the row does not supply
        # enough hard negatives. The pool also contains positives, including
        # the ones of the current row.
        num_missing = num_neg_sample + num_random_negs - len(contexts_neg)
        if num_missing > 0:
            ctxs.extend(random.sample(neg_candidates, num_missing))

        # Concat texts with sep token
        concat_ctxs = [
            maybe_add_title(ctx["text"], row["question"], True, self.sep_token)
            for ctx in ctxs
        ]
        all_ctxs.extend(concat_ctxs)
        # One label per row; presumably the index of the positive context
        # within the row's contexts - verify against the consumer.
        all_labels.append("0")

    return self.text_transform(
        {
            "text": all_ctxs,
            "label": all_labels,
        }
    )