def __init__()

in mdr/qa/qa_dataset.py [0:0]


    def __init__(self,
        tokenizer,
        data_path,
        max_seq_len,
        max_q_len,
        train=False,
        no_sent_label=False
        ):
        """Load retriever outputs and flatten them into (question, passage-chain) examples.

        Each non-blank line of ``data_path`` is parsed as one JSON object. Every
        record must provide "question", "_id" and "candidate_chains"; training
        records additionally need "sp", "answer" and "type".

        Args:
            tokenizer: stored on ``self.tokenizer``; not used while loading.
            data_path: path to the JSON-lines retriever output file.
            max_seq_len: stored on ``self.max_seq_len``.
            max_q_len: stored on ``self.max_q_len``.
            train: if True, emit one positive example per record (its gold
                "sp" chain, label 1) plus one negative (label 0) for every
                candidate chain whose title set differs from the gold one, and
                populate ``self.qid2gold`` / ``self.qid2neg``. Otherwise emit
                one example per candidate chain, labelled 1/0 by comparison
                with the gold title set, or -1 when the record has no (or an
                empty) "sp" field.
            no_sent_label: in training mode, skip building sentence-level
                supporting-fact data (``sp_sent_labels`` and ``sp_gold`` stay
                empty lists).
        """
        # Read eagerly so tqdm knows the total, and close the file promptly.
        # JSON text is UTF-8, so don't depend on the platform's default locale.
        with open(data_path, encoding="utf-8") as f:
            lines = f.readlines()
        # Blank lines (e.g. a trailing newline at EOF) are skipped; json.loads rejects them.
        retriever_outputs = [json.loads(line) for line in tqdm(lines) if line.strip()]

        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.max_q_len = max_q_len
        self.train = train
        self.no_sent_label = no_sent_label
        self.simple_tok = SimpleTokenizer()
        self.data = []

        def _strip_question_mark(question):
            # Drop a single trailing "?" from the question text.
            return question[:-1] if question.endswith("?") else question

        def _sp_gold(gold_passages):
            # Flatten gold supporting facts into [title, sentence_id] pairs.
            return [
                [sp["title"], sent_id]
                for sp in gold_passages
                for sent_id in sp["sp_sent_ids"]
            ]

        if train:
            # qid -> indices into self.data of that question's gold / negative examples.
            self.qid2gold = collections.defaultdict(list)
            self.qid2neg = collections.defaultdict(list)
            for item in retriever_outputs:
                question = _strip_question_mark(item["question"])
                qid = item["_id"]
                gold_passages = item["sp"]

                sp_sent_labels = []
                sp_gold = []
                if not self.no_sent_label:
                    sp_gold = _sp_gold(gold_passages)
                    # One 0/1 label per sentence, concatenated over the gold passages in order.
                    sp_sent_labels = [
                        int(idx in sp["sp_sent_ids"])
                        for sp in gold_passages
                        for idx in range(len(sp["sents"]))
                    ]

                question_type = item["type"]
                self.data.append({
                    "question": question,
                    "passages": gold_passages,
                    "label": 1,
                    "qid": qid,
                    "gold_answer": item["answer"],
                    "sp_sent_labels": sp_sent_labels,
                    "ans_covered": 1, # includes partial chains.
                    "sp_gold": sp_gold
                })
                self.qid2gold[qid].append(len(self.data) - 1)

                sp_titles = {p["title"] for p in gold_passages}
                # Only bridge questions get answer-coverage credit: titles of gold
                # passages whose text contains the answer. Empty otherwise, so
                # every non-bridge negative gets ans_covered == 0.
                if question_type == "bridge":
                    ans_titles = {
                        p["title"]
                        for p in gold_passages
                        if para_has_answer(item["answer"], "".join(p["sents"]), self.simple_tok)
                    }
                else:
                    ans_titles = set()

                # Every candidate chain that is not the gold chain becomes a negative.
                # NOTE(review): no cap is applied to the number of answer-covering
                # (distantly supervised) negatives kept per question.
                for chain in item["candidate_chains"]:
                    chain_titles = {p["title"] for p in chain}
                    if chain_titles == sp_titles:
                        continue
                    answer_covered = int(bool(chain_titles & ans_titles))
                    self.data.append({
                        "question": question,
                        "passages": chain,
                        "label": 0,
                        "qid": qid,
                        "gold_answer": item["answer"],
                        "ans_covered": answer_covered,
                        "sp_gold": sp_gold
                    })
                    self.qid2neg[qid].append(len(self.data) - 1)
        else:
            for item in retriever_outputs:
                question = _strip_question_mark(item["question"])

                # Gold annotations are optional here; when present they are used
                # to label each candidate chain for validation.
                has_sp = "sp" in item
                sp_titles = {p["title"] for p in item["sp"]} if has_sp else None
                sp_gold = _sp_gold(item["sp"]) if has_sp else []
                gold_answer = item.get("answer", [])

                # NOTE: candidate chains are not de-duplicated; chains with the
                # same title set produce separate examples.
                for chain in item["candidate_chains"]:
                    chain_titles = {p["title"] for p in chain}
                    # -1 marks "unknown" when there is no (or an empty) gold chain.
                    label = int(chain_titles == sp_titles) if sp_titles else -1
                    self.data.append({
                        "question": question,
                        "passages": chain,
                        "label": label,
                        "qid": item["_id"],
                        "gold_answer": gold_answer,
                        "sp_gold": sp_gold
                    })

        print(f"Data size {len(self.data)}")