in mdr/qa/qa_dataset.py [0:0]
def __init__(self,
             tokenizer,
             data_path,
             max_seq_len,
             max_q_len,
             train=False,
             no_sent_label=False
             ):
    """Load retriever outputs and flatten them into per-chain examples.

    ``data_path`` is read as a JSON-lines file (one JSON object per line).
    Every object must provide ``"question"``, ``"_id"`` and
    ``"candidate_chains"`` (a list of chains, each a list of passages
    with a ``"title"``).

    Training mode (``train=True``) additionally reads ``"sp"``, ``"type"``
    and ``"answer"`` from every object. For each question it emits one
    gold example (``label=1``, passages taken from ``"sp"``) followed by
    one negative example (``label=0``) per candidate chain whose title
    set differs from the gold title set. ``self.qid2gold`` and
    ``self.qid2neg`` map a question id to the indices of those examples
    in ``self.data``; they are only created in training mode.

    Evaluation mode emits one example per candidate chain. ``"sp"`` and
    ``"answer"`` are optional there: without (or with empty) ``"sp"`` the
    label is ``-1``, otherwise it is ``1`` when the chain's title set
    equals the gold title set and ``0`` when it does not.

    Args:
        tokenizer: Stored on the instance; not used in the constructor.
        data_path: Path of the JSON-lines file with retriever outputs.
        max_seq_len: Stored on the instance; not used in the constructor.
        max_q_len: Stored on the instance; not used in the constructor.
        train: Selects training-mode or evaluation-mode example building.
        no_sent_label: In training mode, when true, ``sp_gold`` and
            ``sp_sent_labels`` are left empty. Not consulted in
            evaluation mode.
    """
    # Use a context manager so the handle is closed deterministically;
    # readlines() keeps a known length for the progress bar.
    with open(data_path) as f:
        retriever_outputs = [json.loads(line) for line in tqdm(f.readlines())]

    self.tokenizer = tokenizer
    self.max_seq_len = max_seq_len
    self.max_q_len = max_q_len
    self.train = train
    self.no_sent_label = no_sent_label
    self.simple_tok = SimpleTokenizer()
    self.data = []

    if train:
        # question id -> indices into self.data
        self.qid2gold = collections.defaultdict(list)
        self.qid2neg = collections.defaultdict(list)

        for item in retriever_outputs:
            # Drop a single trailing question mark, if any.
            question = item["question"]
            if question.endswith("?"):
                question = question[:-1]

            qid = item["_id"]
            answer = item["answer"]
            gold_passages = item["sp"]

            # Sentence-level supervision: [title, sentence id] pairs plus a
            # flat 0/1 label per sentence over all gold passages, in order.
            sp_sent_labels = []
            sp_gold = []
            if not self.no_sent_label:
                for sp in gold_passages:
                    sp_sent_ids = sp["sp_sent_ids"]
                    sp_gold.extend([sp["title"], sent_id] for sent_id in sp_sent_ids)
                    sp_sent_labels.extend(
                        int(idx in sp_sent_ids) for idx in range(len(sp["sents"]))
                    )

            question_type = item["type"]

            self.data.append({
                "question": question,
                "passages": gold_passages,
                "label": 1,
                "qid": qid,
                "gold_answer": answer,
                "sp_sent_labels": sp_sent_labels,
                "ans_covered": 1,  # includes partial chains.
                "sp_gold": sp_gold
            })
            self.qid2gold[qid].append(len(self.data) - 1)

            sp_titles = {p["title"] for p in gold_passages}

            # Only bridge questions get answer-bearing titles; for every
            # other type the set stays empty, so no negative chain is ever
            # marked as covering the answer.
            # NOTE(review): para_has_answer is defined elsewhere; presumably
            # it tests whether the answer occurs in the passage text.
            if question_type == "bridge":
                ans_titles = {
                    p["title"] for p in gold_passages
                    if para_has_answer(answer, "".join(p["sents"]), self.simple_tok)
                }
            else:
                ans_titles = set()

            # Top ranked negative chains. No cap is applied to the number of
            # answer-covering (distantly supervised) chains: every chain
            # whose title set differs from the gold one is kept.
            for chain in item["candidate_chains"]:
                chain_titles = {p["title"] for p in chain}
                if chain_titles == sp_titles:
                    continue
                answer_covered = int(bool(chain_titles & ans_titles))
                self.data.append({
                    "question": question,
                    "passages": chain,
                    "label": 0,
                    "qid": qid,
                    "gold_answer": answer,
                    "ans_covered": answer_covered,
                    "sp_gold": sp_gold
                })
                self.qid2neg[qid].append(len(self.data) - 1)
    else:
        for item in retriever_outputs:
            # Drop a single trailing question mark, if any.
            question = item["question"]
            if question.endswith("?"):
                question = question[:-1]

            # For validation, add target predictions when they are available.
            has_sp = "sp" in item
            sp_titles = {p["title"] for p in item["sp"]} if has_sp else None
            gold_answer = item.get("answer", [])
            sp_gold = []
            if has_sp:
                for sp in item["sp"]:
                    sp_gold.extend(
                        [sp["title"], sent_id] for sent_id in sp["sp_sent_ids"]
                    )

            qid = item["_id"]
            # Chains are not de-duplicated: every candidate chain yields an
            # example, even if another chain has the same set of titles.
            for chain in item["candidate_chains"]:
                if sp_titles:
                    chain_titles = {p["title"] for p in chain}
                    label = int(chain_titles == sp_titles)
                else:
                    # Missing or empty gold passages: label is unknown.
                    label = -1
                self.data.append({
                    "question": question,
                    "passages": chain,
                    "label": label,
                    "qid": qid,
                    "gold_answer": gold_answer,
                    "sp_gold": sp_gold
                })

    print(f"Data size {len(self.data)}")