def generate_facts_for_db()

in dataset-construction/src/ndb_data/generation/question_to_db.py [0:0]


def _append_generated_question(
    generated, question, answers, qtype, fact_ids, derivations, height, rels
):
    """Append one question record to the parallel lists held in ``generated``.

    ``generate_facts_for_db`` stores questions as seven parallel lists that are
    zipped back together at the end, so every record must extend all of them.
    """
    generated["questions"].append(question)
    generated["question_answers"].append(answers)
    generated["question_types"].append(qtype)
    generated["question_facts"].append(fact_ids)
    generated["question_derivations"].append(derivations)
    generated["heights"].append(height)
    generated["rels"].append(rels)


def generate_facts_for_db(db):
    """Build a fact database plus the questions/answers asked over it.

    ``db`` is a list of positive fact records. The keys read below are
    ``idx``, ``qid``, ``entity_ids`` (``subject``/``relation``), ``generated``
    (``question``/``derivation``), ``template`` (``question_type``) and
    ``instance`` (``candidate``).

    Steps:
      1. group the positive facts by question id;
      2. sample extra "negative" facts about other subjects and append them to
         ``db`` (``db`` is mutated in place, then shuffled);
      3. answer every active question over the whole database;
      4. randomly add "bonus" questions whose facts are also in the DB;
      5. randomly add questions answered over prefixes ``ordering[:i]`` of the
         shuffled DB (smaller databases), favouring min/max/argmin/argmax.

    Reads these names from module scope (not visible in this function):
    ``logger``, ``by_subj``, ``by_question``, ``by_idx``, ``db_target_size``,
    ``args``, ``generate_answers`` and ``linearize``.

    Returns:
        A dict with the raw parallel lists (``questions``, ``question_answers``,
        ``question_types``, ``question_facts``, ``question_derivations``,
        ``heights``, ``rels``), the ``indexes``/``qids``/``subjects``/
        ``relations``/``subj_rels`` sets of the positive facts, ``qs`` (the
        zipped question records) and ``facts`` (candidate text in DB order).
    """
    generated = {
        "question_answers": [],
        "question_derivations": [],
        "question_facts": [],
        "question_types": [],
        "questions": [],
        "heights": [],
        "rels": [],
    }

    generated["indexes"] = {f["idx"] for f in db}
    generated["qids"] = {f["qid"] for f in db}
    generated["subjects"] = {f["entity_ids"]["subject"] for f in db}
    generated["relations"] = {f["entity_ids"]["relation"] for f in db}
    generated["subj_rels"] = {
        (f["entity_ids"]["subject"], f["entity_ids"]["relation"]) for f in db
    }

    # Store the list of questions we are going to make queries over
    logger.info("Add positive facts (with original questions)")
    active_questions = defaultdict(list)
    for fact in db:
        active_questions[fact["qid"]].append(fact)

    # Then add negative facts (facts which have no questions): facts about
    # subjects not already in the DB, restricted to relations that no
    # positive fact uses.
    logger.info("Add negative facts (without questions)")
    # random.sample() requires a sequence (sets raise TypeError from Python
    # 3.11) and raises ValueError if k exceeds the population, so build a
    # list and clamp k.
    candidate_subjects = [s for s in by_subj if s not in generated["subjects"]]
    alternative_subjects = random.sample(
        candidate_subjects, min(len(candidate_subjects), db_target_size)
    )
    extra_negative_facts_ids = set()
    extra_negative_facts = []
    for subj in alternative_subjects:
        additional_qs = by_subj[subj]
        # NOTE(review): assumes question ids look like "<subj>_<rel>..." so
        # that split("_")[1] is the relation id — confirm against by_subj.
        extra_negative_queries = [
            q
            for q in random.sample(additional_qs, min(len(additional_qs), 100))
            if q.split("_")[1] not in generated["relations"]
        ]
        for q in extra_negative_queries:
            found = random.choice(by_question[q])
            # Deduplicate by fact idx so the same fact is not added twice.
            if found["idx"] not in extra_negative_facts_ids:
                extra_negative_facts.append(found)
                extra_negative_facts_ids.add(found["idx"])

    negative_facts_to_add = args.extra_negative_facts

    logger.info(f"Adding {len(extra_negative_facts)} extras")
    if extra_negative_facts:
        db.extend(
            random.sample(
                extra_negative_facts,
                min(len(extra_negative_facts), negative_facts_to_add),
            )
        )

    # Make an ordering of the facts that the instances will be inserted in
    logger.info("Generate random order DB")
    random.shuffle(db)
    ordering = [f["idx"] for f in db]
    # O(1) replacement for ordering.index(); setdefault keeps the first
    # occurrence, matching list.index semantics if an idx were repeated.
    position = {}
    for pos, idx in enumerate(ordering):
        position.setdefault(idx, pos)
    db_height = len(db)

    resc_qids = set()

    # For the active questions, generate positive answers over the entire DB
    logger.info("Generate positive answers")
    for qid, facts in active_questions.items():
        resc_qids.add(qid)
        question_text = random.choice(
            [fact["generated"]["question"] for fact in facts]
        )
        question_type = facts[0]["template"]["question_type"]
        _append_generated_question(
            generated,
            question=question_text,
            answers=generate_answers(question_text, question_type, facts),
            qtype=question_type,
            fact_ids=[position[fact["idx"]] for fact in facts],
            derivations=[fact["generated"]["derivation"] for fact in facts],
            height=db_height,
            rels=[fact["entity_ids"]["relation"] for fact in facts],
        )

    # For the active facts, collect every other question id those facts
    # take part in (excluding the questions we already have).
    additional_ids = set()
    for idx in generated["indexes"]:
        for question in by_idx[idx]:
            additional_ids.update(fact["qid"] for fact in by_question[question])
    additional_ids = additional_ids.difference(generated["qids"])

    # Generate extra questions for these facts
    logger.info("Generate bonus answers")
    for qid in additional_ids:
        extra = [a for a in by_question[qid] if a["idx"] in generated["indexes"]]
        # Keep roughly 20% of the bonus questions that have facts in the DB.
        if extra and random.uniform(0, 1) < 0.2:
            resc_qids.add(qid)
            question_text = random.choice([e["generated"]["question"] for e in extra])
            question_type = extra[0]["template"]["question_type"]
            _append_generated_question(
                generated,
                question=question_text,
                answers=generate_answers(question_text, question_type, extra),
                qtype=question_type,
                fact_ids=[position[e["idx"]] for e in extra],
                derivations=[e["generated"]["derivation"] for e in extra],
                # Answered over the whole DB, so the height is the full DB size,
                # consistent with the positive answers above (was len(db) - 1).
                height=db_height,
                rels=[e["entity_ids"]["relation"] for e in extra],
            )

    # Do the same for each prefix ordering[:height] of the shuffled database
    logger.info("Generate bonus answers for smaller DBs")
    prefix_records = []
    collected_indexes = set()
    for height, next_idx in enumerate(ordering):
        # Invariant: collected_indexes == set(ordering[:height]).
        for qid in resc_qids:
            facts = by_question[qid]
            filtered_facts = [
                a
                for a in facts
                if a["idx"] in generated["indexes"] and a["idx"] in collected_indexes
            ]
            question_text = random.choice([f["generated"]["question"] for f in facts])
            question_type = facts[0]["template"]["question_type"]
            prefix_records.append(
                {
                    "question": question_text,
                    "answers": generate_answers(
                        question_text, question_type, filtered_facts
                    ),
                    "qtype": question_type,
                    "fact_ids": [position[fact["idx"]] for fact in filtered_facts],
                    "derivations": [
                        fact["generated"]["derivation"] for fact in filtered_facts
                    ],
                    "height": height,
                    "rels": [
                        fact["entity_ids"]["relation"] for fact in filtered_facts
                    ],
                }
            )
        collected_indexes.add(next_idx)

    # Randomly keep some of the prefix questions. master_answers is a snapshot
    # taken before any prefix record is appended.
    master_answers = {
        linearize(a)
        for a in zip(generated["questions"], generated["question_answers"])
    }
    for record in prefix_records:
        answer = record["answers"]
        # (question, answer) pairs already present are kept only rarely.
        if linearize((record["question"], answer)) in master_answers:
            sample_prob = 0.05
        else:
            sample_prob = 0.3
        # Boost min/max-style questions unless the answer contains None.
        if None not in answer and record["qtype"] in ("argmin", "argmax", "min", "max"):
            sample_prob += 0.4
        # Very small databases are almost never kept.
        if record["height"] < 4:
            sample_prob = 0.001
        if random.uniform(0, 1) <= sample_prob:
            _append_generated_question(generated, **record)

    # Go through all generated questions/answers and zip them together
    generated["qs"] = []
    for question, answer, qtype, fact, derivation, height, rel in zip(
        generated["questions"],
        generated["question_answers"],
        generated["question_types"],
        generated["question_facts"],
        generated["question_derivations"],
        generated["heights"],
        generated["rels"],
    ):
        # Skip questions that reference the same DB position more than once.
        if len(fact) != len(set(fact)):
            continue
        generated["qs"].append(
            {
                # Also patches the "Whow many" typo seen in generated text.
                "question": question.strip().replace("Whow many", "How many"),
                "answer": answer,
                "type": qtype,
                "facts": fact,
                # NOTE: "deriations" (sic) is kept because consumers of this
                # output may depend on the existing key name.
                "deriations": derivation,
                "height": height,
                "relation": rel,
            }
        )

    generated["facts"] = [q["instance"]["candidate"].strip() for q in db]
    logger.info(f"Added {len(generated['qs'])} queries to DB")
    return generated