in dataset-construction/src/ndb_data/generation/question_to_db.py [0:0]
def _answers_arg_extreme(question_facts, largest):
    """Return the key(s) with the smallest or largest associated value.

    Each fact's ``generated.derivation`` is split into a key and a value.
    If any value is numeric (per ``try_numeric``), keys are ranked by the
    value of their *first* recorded entry; otherwise keys are ranked by how
    many entries they collected.

    Side effect: for symmetric facts the ``[SEP]`` marker in
    ``generated.derivation`` is rewritten to ``[SYM]`` in place on the
    caller's fact dict.

    Args:
        question_facts: iterable of fact dicts with ``symmetric`` and
            ``generated.derivation`` entries.
        largest: ``True`` selects the maximum (argmax), ``False`` the
            minimum (argmin).

    Returns:
        A list of winning keys, or ``[None]`` when there are no facts.
    """
    answers = defaultdict(list)
    answer_keys = {}
    numeric_answers = False

    for question in question_facts:
        if question["symmetric"]:
            ak = None
            # Symmetric facts relate two entities in both directions; the
            # separator is normalised to [SYM] on the fact itself.
            question["generated"]["derivation"] = question["generated"][
                "derivation"
            ].replace("[SEP]", "[SYM]")
            k, v = question["generated"]["derivation"].split("[SYM]")
        else:
            # maybe_split is defined elsewhere in this module; presumably
            # `ak` is an alternative answer form for the key -- TODO confirm.
            k, v, ak = maybe_split(question["generated"]["derivation"])

        assert "[SEP]" not in k
        assert "[SEP]" not in v

        if try_numeric(v):
            numeric_answers = True

        answers[k.strip()].append((convert_comparable(v.strip()), question))
        answer_keys[k.strip()] = ak
        if question["symmetric"]:
            # Record the reverse direction as well.
            answers[v.strip()].append((convert_comparable(k.strip()), question))

    if not answers:
        return [None]

    if not numeric_answers:
        # Non-numeric values: rank keys by how many facts mention them.
        best = sorted(answers.items(), key=lambda a: len(a[1]), reverse=largest)
        target = len(best[0][1])
        best_answers = {k: v for k, v in best if len(v) == target}
    else:
        first_values = [entries[0][0] for entries in answers.values()]
        if len({type(value) for value in first_values}) > 1:
            # Mixed value types may not be orderable (the sort below may
            # raise); dump the last processed fact and the values seen.
            print(question)
            print(first_values)
        # Numeric values: rank keys by the value of their first entry only.
        best = sorted(answers.items(), key=lambda a: a[1][0][0], reverse=largest)
        target = best[0][1][0]
        # NOTE(review): `target` is the whole (value, fact) tuple, so another
        # key only ties with the winner when its first entry has the same
        # value *and* an equal fact dict. Confirm whether ties on the value
        # alone were intended before changing this.
        best_answers = {k: v for k, v in best if v[0] == target}

    resolved = []
    for key in best_answers:
        alternative = answer_keys.get(key)
        resolved.append(alternative if alternative is not None else key)
    return resolved


def _answers_extreme_value(question_facts, largest, require_numeric):
    """Return ``[min]`` or ``[max]`` of the fact values, or ``[None]`` if empty.

    When a derivation contains ``[SEP]``, only the text after the first
    ``[SEP]`` is used as the value.

    Args:
        question_facts: iterable of fact dicts with ``generated.derivation``.
        largest: ``True`` for the maximum, ``False`` for the minimum.
        require_numeric: when true, assert that every value passes
            ``try_numeric``.
    """
    values = []
    for question in question_facts:
        v = question["generated"]["derivation"]
        if "[SEP]" in v:
            v = v.split("[SEP]", maxsplit=1)[1].strip()
        assert "[SEP]" not in v, v
        if require_numeric:
            assert try_numeric(v)
        values.append(convert_comparable(v))

    if not values:
        return [None]
    return [np.max(values) if largest else np.min(values)]


def _answers_distinct(question_facts, allowed=None):
    """Return the set of distinct converted derivations of the facts.

    Args:
        question_facts: iterable of fact dicts with ``generated.derivation``.
        allowed: optional collection of raw derivation strings; when given,
            every derivation is asserted to be one of them.
    """
    values = set()
    for question in question_facts:
        v = question["generated"]["derivation"]
        assert "[SEP]" not in v, v
        if allowed is not None:
            assert v in allowed
        values.add(convert_comparable(v))
    return values


def generate_answers(question_text, question_type, question_facts):
    """Compute the answer list for a question from its supporting facts.

    Supported ``question_type`` values:

    * ``"argmin"`` / ``"argmax"``: key(s) with the smallest / largest value
      (or the fewest / most facts when no value is numeric).
    * ``"min"`` / ``"max"``: the smallest / largest value, as a one-item list.
    * ``"count"``: number of distinct values, as a one-item list.
    * ``"bool"``: the distinct ``TRUE``/``FALSE`` values after conversion.
    * ``"set"``: the distinct values.

    Args:
        question_text: the question string; not used by this function.
        question_type: one of the types listed above.
        question_facts: facts for the question; every fact's ``qid`` must
            start with ``question_type``. Symmetric facts handled by
            ``argmin``/``argmax`` have their derivation rewritten in place
            (``[SEP]`` becomes ``[SYM]``).

    Returns:
        A list of answers. With no facts, ``argmin``/``argmax``/``min``/
        ``max``/``bool`` return ``[None]``, ``count`` returns ``[0]`` and
        ``set`` returns ``[]``.

    Raises:
        AssertionError: if a fact's ``qid`` does not match ``question_type``
            or a derivation is malformed for the requested type.
        ValueError: if ``question_type`` is not supported.
    """
    assert all(q["qid"].startswith(question_type) for q in question_facts)

    if question_type == "argmin":
        return _answers_arg_extreme(question_facts, largest=False)
    if question_type == "argmax":
        return _answers_arg_extreme(question_facts, largest=True)
    if question_type == "min":
        return _answers_extreme_value(
            question_facts, largest=False, require_numeric=True
        )
    if question_type == "max":
        # NOTE(review): unlike "min", values are not asserted to be numeric
        # here; kept as-is to avoid rejecting data that was accepted before.
        return _answers_extreme_value(
            question_facts, largest=True, require_numeric=False
        )
    if question_type == "count":
        return [len(_answers_distinct(question_facts))]
    if question_type == "bool":
        truth_values = _answers_distinct(question_facts, allowed=("TRUE", "FALSE"))
        return list(truth_values) if truth_values else [None]
    if question_type == "set":
        return list(_answers_distinct(question_facts))

    raise ValueError(f"Unknown question type: {question_type!r}")