dataset-construction/src/ndb_data/generation/question_to_db.py [159:242]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    answer_keys = dict()
    if question_type == "argmin":
        # get all keys and find the key with the smallest count
        answers = defaultdict(list)
        numeric_answers = False
        for question in question_facts:
            if question["symmetric"]:
                ak = None
                question["generated"]["derivation"] = question["generated"][
                    "derivation"
                ].replace("[SEP]", "[SYM]")
                k, v = question["generated"]["derivation"].split("[SYM]")
            else:
                k, v, ak = maybe_split(question["generated"]["derivation"])
            assert "[SEP]" not in k
            assert "[SEP]" not in v

            if try_numeric(v):
                numeric_answers = True

            answers[k.strip()].append((convert_comparable(v.strip()), question))
            answer_keys[k.strip()] = ak

            if question["symmetric"]:
                answers[v.strip()].append((convert_comparable(k.strip()), question))

        if not len(answers):
            return [None]
        else:
            if not numeric_answers:
                best = sorted(answers.items(), key=lambda a: len(a[1]), reverse=False)
                lowest = len(best[0][1])
                best_answers = {k: v for k, v in best if len(v) == lowest}
            else:
                if len(set(type(a[1][0][0]) for a in answers.items())) > 1:
                    print(question)
                    print([a[1][0][0] for a in answers.items()])
                best = sorted(answers.items(), key=lambda a: a[1][0][0], reverse=False)
                lowest = best[0][1][0]
                best_answers = {k: v for k, v in best if v[0] == lowest}
            return list(
                [
                    answer_keys[k]
                    if k in answer_keys and answer_keys[k] is not None
                    else k
                    for k in best_answers.keys()
                ]
            )

    elif question_type == "argmax":
        # get all keys and find the key with the largest count
        answers = defaultdict(list)

        numeric_answers = False
        for question in question_facts:
            if question["symmetric"]:
                ak = None
                question["generated"]["derivation"] = question["generated"][
                    "derivation"
                ].replace("[SEP]", "[SYM]")
                k, v = question["generated"]["derivation"].split("[SYM]")
            else:
                k, v, ak = maybe_split(question["generated"]["derivation"])
            assert "[SEP]" not in k
            assert "[SEP]" not in v

            if try_numeric(v):
                numeric_answers = True

            answers[k.strip()].append((convert_comparable(v.strip()), question))
            answer_keys[k.strip()] = ak

            if question["symmetric"]:
                answers[v.strip()].append((convert_comparable(k.strip()), question))

        if not len(answers):
            return [None]
        else:
            if not numeric_answers:
                best = sorted(answers.items(), key=lambda a: len(a[1]), reverse=True)
                highest = len(best[0][1])
                best_answers = {k: v for k, v in best if len(v) == highest}
            else:
                if len(set(type(a[1][0][0]) for a in answers.items())) > 1:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



modelling/src/neuraldb/convert_spj_to_predictions.py [122:205]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    answer_keys = dict()
    if question_type == "argmin":
        # get all keys and find the key with the smallest count
        answers = defaultdict(list)
        numeric_answers = False
        for question in question_facts:
            if question["symmetric"]:
                ak = None
                question["generated"]["derivation"] = question["generated"][
                    "derivation"
                ].replace("[SEP]", "[SYM]")
                k, v = question["generated"]["derivation"].split("[SYM]")
            else:
                k, v, ak = maybe_split(question["generated"]["derivation"])
            assert "[SEP]" not in k
            assert "[SEP]" not in v

            if try_numeric(v):
                numeric_answers = True

            answers[k.strip()].append((convert_comparable(v.strip()), question))
            answer_keys[k.strip()] = ak

            if question["symmetric"]:
                answers[v.strip()].append((convert_comparable(k.strip()), question))

        if not len(answers):
            return [None]
        else:
            if not numeric_answers:
                best = sorted(answers.items(), key=lambda a: len(a[1]), reverse=False)
                lowest = len(best[0][1])
                best_answers = {k: v for k, v in best if len(v) == lowest}
            else:
                if len(set(type(a[1][0][0]) for a in answers.items())) > 1:
                    print(question)
                    print([a[1][0][0] for a in answers.items()])
                best = sorted(answers.items(), key=lambda a: a[1][0][0], reverse=False)
                lowest = best[0][1][0]
                best_answers = {k: v for k, v in best if v[0] == lowest}
            return list(
                [
                    answer_keys[k]
                    if k in answer_keys and answer_keys[k] is not None
                    else k
                    for k in best_answers.keys()
                ]
            )

    elif question_type == "argmax":
        # get all keys and find the key with the largest count
        answers = defaultdict(list)

        numeric_answers = False
        for question in question_facts:
            if question["symmetric"]:
                ak = None
                question["generated"]["derivation"] = question["generated"][
                    "derivation"
                ].replace("[SEP]", "[SYM]")
                k, v = question["generated"]["derivation"].split("[SYM]")
            else:
                k, v, ak = maybe_split(question["generated"]["derivation"])
            assert "[SEP]" not in k
            assert "[SEP]" not in v

            if try_numeric(v):
                numeric_answers = True

            answers[k.strip()].append((convert_comparable(v.strip()), question))
            answer_keys[k.strip()] = ak

            if question["symmetric"]:
                answers[v.strip()].append((convert_comparable(k.strip()), question))

        if not len(answers):
            return [None]
        else:
            if not numeric_answers:
                best = sorted(answers.items(), key=lambda a: len(a[1]), reverse=True)
                highest = len(best[0][1])
                best_answers = {k: v for k, v in best if len(v) == highest}
            else:
                if len(set(type(a[1][0][0]) for a in answers.items())) > 1:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



