in dataset-construction/src/ndb_data/generation/question_to_db.py [0:0]
def generate_facts_for_db(db):
    """Generate the question/answer records for one database of facts.

    Pipeline:
      1. Group the facts in ``db`` by question id (the "positive" facts).
      2. Sample extra "negative" facts from up to ``db_target_size`` subjects,
         keeping only question keys whose second ``_``-separated component is
         not one of the relations already in ``db``; at most
         ``args.extra_negative_facts`` of them are appended to ``db``.
      3. Shuffle ``db`` to fix the insertion order of the facts.
      4. Answer every original question over the full database.
      5. With probability 0.2 each, answer other questions that share facts
         with the original database.
      6. Answer the questions from steps 4-5 over every proper prefix of the
         insertion order (including the empty one) and subsample the results.
      7. Zip everything into ``generated["qs"]``, dropping questions whose
         fact positions contain duplicates.

    Each fact dict is read through the keys ``"idx"``, ``"qid"``,
    ``["entity_ids"]["subject"]``, ``["entity_ids"]["relation"]``,
    ``["generated"]["question"]``, ``["generated"]["derivation"]``,
    ``["template"]["question_type"]`` and ``["instance"]["candidate"]``.

    Relies on module-level names not defined here: ``by_subj``,
    ``by_question``, ``by_idx``, ``db_target_size``, ``args``,
    ``generate_answers``, ``linearize`` and ``logger``.

    Args:
        db: list of fact dicts. It is MUTATED in place: negative facts are
            appended to it and it is shuffled.

    Returns:
        dict holding the parallel bookkeeping lists, the id/subject/relation
        sets of the original facts, ``"qs"`` (list of question records) and
        ``"facts"`` (stripped candidate text of each fact, in insertion order).
    """
    generated = {
        "question_answers": [],
        "question_derivations": [],
        "question_facts": [],
        "question_types": [],
        "questions": [],
        "heights": [],
        "rels": [],
        # Snapshot of the ORIGINAL facts, taken before negatives are added.
        "indexes": {f["idx"] for f in db},
        "qids": {f["qid"] for f in db},
        "subjects": {f["entity_ids"]["subject"] for f in db},
        "relations": {f["entity_ids"]["relation"] for f in db},
        "subj_rels": {
            (f["entity_ids"]["subject"], f["entity_ids"]["relation"]) for f in db
        },
    }

    # Store the list of questions we are going to make queries over.
    logger.info("Add positive facts (with original questions)")
    active_questions = defaultdict(list)
    for query in db:
        active_questions[query["qid"]].append(query)

    # Then add negative facts (facts which have no questions).
    logger.info("Add negative facts (without questions)")
    # random.sample() rejects sets since Python 3.11, so sample from a list of
    # the keys. Clamp k so a small by_subj does not raise ValueError.
    # NOTE(review): the original called `.difference()` with no argument,
    # which excluded nothing; was `.difference(generated["subjects"])` meant?
    candidate_subjects = list(by_subj.keys())
    alternative_subjects = random.sample(
        candidate_subjects, min(db_target_size, len(candidate_subjects))
    )
    extra_negative_facts_ids = set()
    extra_negative_facts = []
    for subj in alternative_subjects:
        additional_qs = by_subj[subj]
        extra_negative_queries = [
            r
            for r in random.sample(additional_qs, min(len(additional_qs), 100))
            if r.split("_")[1] not in generated["relations"]
        ]
        for q in extra_negative_queries:
            found = random.choice(by_question[q])
            if found["idx"] not in extra_negative_facts_ids:
                extra_negative_facts.append(found)
                extra_negative_facts_ids.add(found["idx"])
    negative_facts_to_add = args.extra_negative_facts
    logger.info(f"Adding {len(extra_negative_facts)} extras")
    if extra_negative_facts:
        db.extend(
            random.sample(
                extra_negative_facts,
                min(len(extra_negative_facts), negative_facts_to_add),
            )
        )

    # Make an ordering of the facts that the instances will be inserted in.
    logger.info("Generate random order DB")
    random.shuffle(db)
    ordering = [f["idx"] for f in db]
    # O(1) position lookup; setdefault keeps the FIRST occurrence so results
    # match the previous ordering.index() calls exactly.
    position = {}
    for pos, fact_idx in enumerate(ordering):
        position.setdefault(fact_idx, pos)

    # For the active questions, generate positive answers over the entire DB.
    resc_qids = set()
    logger.info("Generate positive answers")
    for qid, q in active_questions.items():
        resc_qids.add(qid)
        fact_ids = [position[question["idx"]] for question in q]
        question_text = random.choice(
            [question["generated"]["question"] for question in q]
        )
        question_type = q[0]["template"]["question_type"]
        _append_question(
            generated,
            answers=generate_answers(question_text, question_type, q),
            derivations=[question["generated"]["derivation"] for question in q],
            fact_ids=fact_ids,
            question_type=question_type,
            question=question_text,
            height=len(db),
            rels=[question["entity_ids"]["relation"] for question in q],
        )

    # Other questions that touch the facts already in the DB.
    additional_ids = set()
    for idx in generated["indexes"]:
        for question in by_idx[idx]:
            additional_ids.update([fact["qid"] for fact in by_question[question]])
    additional_ids = additional_ids.difference(generated["qids"])

    # Generate extra questions for these facts.
    logger.info("Generate bonus answers")
    for qid in additional_ids:
        extra = [a for a in by_question[qid] if a["idx"] in generated["indexes"]]
        # random.uniform is only drawn when `extra` is non-empty.
        if extra and random.uniform(0, 1) < 0.2:
            resc_qids.add(qid)
            fact_ids = [position[e["idx"]] for e in extra]
            question_text = random.choice([e["generated"]["question"] for e in extra])
            question_type = extra[0]["template"]["question_type"]
            _append_question(
                generated,
                answers=generate_answers(question_text, question_type, extra),
                derivations=[e["generated"]["derivation"] for e in extra],
                fact_ids=fact_ids,
                question_type=question_type,
                question=question_text,
                # NOTE(review): full-DB questions above use len(db) and their
                # answers here are also over the full DB; confirm the -1.
                height=len(db) - 1,
                rels=[e["entity_ids"]["relation"] for e in extra],
            )

    # Do the same for each prefix (smaller DB) of the insertion order.
    logger.info("Generate bonus answers for smaller DBs")
    smaller_db_records = []
    collected_indexes = set()
    for i, fact_idx in enumerate(ordering):
        # Invariant: collected_indexes == set(ordering[:i]).
        for qid in resc_qids:
            facts = by_question[qid]
            filtered_facts = [
                a
                for a in facts
                if a["idx"] in generated["indexes"] and a["idx"] in collected_indexes
            ]
            question_text = random.choice([f["generated"]["question"] for f in facts])
            question_type = facts[0]["template"]["question_type"]
            fact_ids = [position[fact["idx"]] for fact in filtered_facts]
            smaller_db_records.append(
                {
                    "answers": generate_answers(
                        question_text, question_type, filtered_facts
                    ),
                    "derivations": [
                        fact["generated"]["derivation"] for fact in filtered_facts
                    ],
                    "fact_ids": fact_ids,
                    "question_type": question_type,
                    "question": question_text,
                    "height": i,
                    "rels": [
                        fact["entity_ids"]["relation"] for fact in filtered_facts
                    ],
                }
            )
        collected_indexes.add(fact_idx)

    # Subsample the prefix answers: answers that duplicate a full-DB
    # (question, answer) pair are rarely kept, min/max-style answers are
    # favoured, and very small DBs (< 4 facts) are almost always dropped.
    master_answers = {
        linearize(pair)
        for pair in zip(generated["questions"], generated["question_answers"])
    }
    selected = []
    for record in smaller_db_records:
        answers = record["answers"]
        if linearize((record["question"], answers)) in master_answers:
            sample_prob = 0.05
        else:
            sample_prob = 0.3
        if None not in answers and record["question_type"] in (
            "argmin",
            "argmax",
            "min",
            "max",
        ):
            sample_prob += 0.4
        if record["height"] < 4:
            sample_prob = 0.001
        if random.uniform(0, 1) <= sample_prob:
            selected.append(record)
    for record in selected:
        _append_question(generated, **record)

    # Go through all generated questions/answers and zip them together.
    generated["qs"] = []
    for question, answer, qtype, fact, derivation, height, rel in zip(
        generated["questions"],
        generated["question_answers"],
        generated["question_types"],
        generated["question_facts"],
        generated["question_derivations"],
        generated["heights"],
        generated["rels"],
    ):
        # Skip questions that reference the same fact position more than once.
        if len(fact) != len(set(fact)):
            continue
        generated["qs"].append(
            {
                "question": question.strip().replace("Whow many", "How many"),
                "answer": answer,
                "type": qtype,
                "facts": fact,
                # NOTE: misspelt key left unchanged; presumably read under this
                # name downstream — check consumers before renaming.
                "deriations": derivation,
                "height": height,
                "relation": rel,
            }
        )
    generated["facts"] = [q["instance"]["candidate"].strip() for q in db]
    logger.info(f"Added {len(generated['qs'])} queries to DB")
    return generated


def _append_question(
    generated, answers, derivations, fact_ids, question_type, question, height, rels
):
    """Append one question's data to the parallel lists held in ``generated``."""
    generated["question_answers"].append(answers)
    generated["question_derivations"].append(derivations)
    generated["question_facts"].append(fact_ids)
    generated["question_types"].append(question_type)
    generated["questions"].append(question)
    generated["heights"].append(height)
    generated["rels"].append(rels)