def generate_answers()

in modelling/src/neuraldb/convert_spj_to_predictions.py [0:0]


def _iter_derivations(question_facts, strip_sep_prefix=False):
    """Yield the generated derivation string of every fact, in order.

    Args:
        question_facts: iterable of dicts exposing ``["generated"]["derivation"]``.
        strip_sep_prefix: when true and the derivation contains ``[SEP]``, keep
            only the (stripped) text after the first ``[SEP]``.

    Raises:
        AssertionError: if ``[SEP]`` is still present in the derivation.
    """
    for question in question_facts:
        v = question["generated"]["derivation"]
        if strip_sep_prefix and "[SEP]" in v:
            v = v.split("[SEP]", maxsplit=1)[1].strip()

        assert "[SEP]" not in v, v
        yield v


def _select_extreme_keys(question_facts, largest):
    """Shared implementation of the ``argmin`` / ``argmax`` question types.

    Every derivation is split into a key and a value; values are grouped per
    key. If at least one value is numeric (per ``try_numeric``) keys are ranked
    by the first value recorded for them, otherwise by how many values they
    collected. All keys tied for the extreme rank are returned.

    Args:
        question_facts: iterable of fact dicts with ``"symmetric"`` and
            ``["generated"]["derivation"]`` entries.
        largest: ``True`` to select the maximum (argmax), ``False`` for the
            minimum (argmin).

    Returns:
        ``[None]`` when there are no facts, otherwise the list of winning keys
        (replaced by the third element returned by ``maybe_split`` when that
        element is not ``None``).
    """
    answers = defaultdict(list)
    answer_keys = {}
    numeric_answers = False
    question = None

    for question in question_facts:
        generated = question["generated"]
        if question["symmetric"]:
            ak = None
            # The normalised derivation is written back into the fact, exactly
            # as before, in case callers read it afterwards.
            generated["derivation"] = generated["derivation"].replace(
                "[SEP]", "[SYM]"
            )
            k, v = generated["derivation"].split("[SYM]")
        else:
            k, v, ak = maybe_split(generated["derivation"])
        assert "[SEP]" not in k
        assert "[SEP]" not in v

        if try_numeric(v):
            numeric_answers = True

        answers[k.strip()].append((convert_comparable(v.strip()), question))
        answer_keys[k.strip()] = ak

        # A symmetric fact counts for both of its sides.
        if question["symmetric"]:
            answers[v.strip()].append((convert_comparable(k.strip()), question))

    if not answers:
        return [None]

    if numeric_answers:
        first_values = [entries[0][0] for entries in answers.values()]
        if len({type(value) for value in first_values}) > 1:
            # Mixed value types may not be orderable; leave a trace before sorting.
            logger.warning(question)
            logger.warning(first_values)
        ranked = sorted(
            answers.items(), key=lambda item: item[1][0][0], reverse=largest
        )
        # Compare on the value alone. Comparing the whole (value, question)
        # tuple would drop keys that tie on the value but come from different
        # facts, unlike the count-based branch below which keeps every tie.
        extreme = ranked[0][1][0][0]
        best_keys = [k for k, entries in ranked if entries[0][0] == extreme]
    else:
        ranked = sorted(
            answers.items(), key=lambda item: len(item[1]), reverse=largest
        )
        extreme = len(ranked[0][1])
        best_keys = [k for k, entries in ranked if len(entries) == extreme]

    return [
        answer_keys[k] if answer_keys.get(k) is not None else k for k in best_keys
    ]


def generate_answers(question_type, question_facts):
    """Aggregate generated derivations into the final answer list.

    Args:
        question_type: one of ``"argmin"``, ``"argmax"``, ``"min"``, ``"max"``,
            ``"count"``, ``"bool"`` or ``"set"``.
        question_facts: iterable of fact dicts; each exposes
            ``["generated"]["derivation"]`` and, for argmin/argmax,
            ``"symmetric"``.

    Returns:
        A list of answers. ``argmin``/``argmax``/``min``/``max``/``bool`` yield
        ``[None]`` when there are no facts, ``count`` yields ``[0]`` and
        ``set`` yields ``[]``.

    Raises:
        AssertionError: if a derivation has an unexpected format.
        Exception: if ``question_type`` is not recognised.
    """
    if question_type == "argmin":
        return _select_extreme_keys(question_facts, largest=False)

    elif question_type == "argmax":
        return _select_extreme_keys(question_facts, largest=True)

    elif question_type == "min":
        # Smallest converted value; every derivation must be numeric.
        answers = []
        for v in _iter_derivations(question_facts, strip_sep_prefix=True):
            assert try_numeric(v)
            answers.append(convert_comparable(v))

        best = np.min(answers) if answers else None
        return [best]

    elif question_type == "max":
        # Largest converted value.
        answers = [
            convert_comparable(v)
            for v in _iter_derivations(question_facts, strip_sep_prefix=True)
        ]

        best = np.max(answers) if answers else None
        return [best]

    elif question_type == "count":
        # Number of distinct converted values.
        answers = {convert_comparable(v) for v in _iter_derivations(question_facts)}
        return [len(answers)]

    elif question_type == "bool":
        # Distinct truth values seen; derivations must be "TRUE" or "FALSE".
        answers = set()
        for v in _iter_derivations(question_facts):
            assert v == "TRUE" or v == "FALSE"
            answers.add(convert_comparable(v))

        return list(answers) if answers else [None]

    elif question_type == "set":
        # Distinct converted values.
        answers = {convert_comparable(v) for v in _iter_derivations(question_facts)}
        return list(answers)

    raise Exception("Unknown quesiton type")