lm_eval/tasks/afrimgsm/utils.py (194 lines of code) (raw):

import argparse import yaml languages = [ "eng", "amh", "ibo", "fra", "sna", "lin", "wol", "ewe", "lug", "xho", "kin", "twi", "zul", "orm", "yor", "hau", "sot", "swa", ] languages_REGEX = { "eng": "The answer is (\\-?[0-9\\.\\,]+)", "amh": "መልሱ (\\-?[0-9\\.\\,]+)", "ibo": "Azịza ya bụ (\\-?[0-9\\.\\,]+)", "fra": "La réponse est(\\-?[0-9\\.\\,]+)", "sna": "Mhinduro kumubvunzo ndi (\\-?[0-9\\.\\,]+)", "lin": "Eyano ezali (\\-?[0-9\\.\\,]+)", "wol": "Tontu li (\\-?[0-9\\.\\,]+)", "ewe": "ŋuɖoɖoae nye (\\-?[0-9\\.\\,]+)", "lug": "Ansa eri (\\-?[0-9\\.\\,]+)", "xho": "Impendulo ngu (\\-?[0-9\\.\\,]+)", "kin": "Igisubizo ni (\\-?[0-9\\.\\,]+)", "twi": "Ne nnyiano yɛ (\\-?[0-9\\.\\,]+)", "zul": "Impendulo ithi (\\-?[0-9\\.\\,]+)", "orm": "Deebiin isaa (\\-?[0-9\\.\\,]+)", "yor": "Ìdáhùn náà ni (\\-?[0-9\\.\\,]+)", "hau": "Amsar ita ce (\\-?[0-9\\.\\,]+)", "sot": "Karabo ke (\\-?[0-9\\.\\,]+)", "swa": "Jibu ni (\\-?[0-9\\.\\,]+)", } LANGUAGES = {} for lang in languages: if lang == "amh": LANGUAGES[lang] = { # English "QUESTION": "ጥያቄ:", "ANSWER": "በቅደም ተከተል መልስ:", "DIRECT": "Answer:", "REGEX": languages_REGEX[lang], } elif lang == "yor": LANGUAGES[lang] = { # English "QUESTION": "Ìbéèrè:", "ANSWER": "Ìdáhùn lẹ́sẹsẹ:", "DIRECT": "Answer:", "REGEX": languages_REGEX[lang], } else: LANGUAGES[lang] = { # English "QUESTION": "Question:", "ANSWER": "Step-by-Step Answer:", "DIRECT": "Answer:", "REGEX": languages_REGEX[lang], } def add_regex_pattern(regex_pattern): if regex_pattern is None: return {} return { "filter_list": [ { "name": "strict-match", "filter": [ { "function": "regex", "regex_pattern": f"""{regex_pattern}""", }, { "function": "take_first", }, ], }, { "name": "flexible-extract", "filter": [ { "function": "regex", "regex_pattern": """(-?[$0-9.,]{2,})|(-?[0-9]+)""", "group_select": -1, }, { "function": "take_first", }, ], }, ], } def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None: """ Generate a yaml file for each language. :param output_dir: The directory to output the files to. :param overwrite: Whether to overwrite files if they already exist. """ err = [] for lang in LANGUAGES.keys(): try: yaml_template = "cot_yaml" filter_list = {} DELIMITER = None if mode == "direct": ANSWER = LANGUAGES["eng"]["DIRECT"] QUESTION = LANGUAGES["eng"]["QUESTION"] REGEX = None task_name = f"afrimgsm_direct_{lang}" yaml_template = "direct_yaml" if mode == "direct-native": ANSWER = LANGUAGES[lang]["DIRECT"] QUESTION = LANGUAGES[lang]["QUESTION"] REGEX = None task_name = f"afrimgsm_direct_native_{lang}" yaml_template = "direct_native_yaml" elif mode == "native-cot": ANSWER = LANGUAGES[lang]["ANSWER"] REGEX = LANGUAGES[lang]["REGEX"] QUESTION = LANGUAGES[lang]["QUESTION"] task_name = f"afrimgsm_native_cot_{lang}" filter_list = add_regex_pattern(REGEX) DELIMITER = "" if lang in ["zh", "ja"] else None elif mode == "en-cot": ANSWER = LANGUAGES["eng"]["ANSWER"] REGEX = LANGUAGES["eng"]["REGEX"] QUESTION = LANGUAGES["eng"]["QUESTION"] task_name = f"afrimgsm_en_cot_{lang}" elif mode == "translate-direct": ANSWER = LANGUAGES["eng"]["DIRECT"] QUESTION = LANGUAGES["eng"]["QUESTION"] REGEX = None task_name = f"afrimgsm_translate_direct_{lang}" yaml_template = "translate_direct_yaml" file_name = f"{task_name}.yaml" ANSWER_TO_SKIP = len(LANGUAGES[lang]["ANSWER"]) + 1 with open( f"{output_dir}/{file_name}", "w" if overwrite else "x", encoding="utf8" ) as f: f.write("# Generated by utils.py\n") yaml.dump( { "include": yaml_template, "dataset_name": lang, "task": f"{task_name}", "doc_to_text": f"""{{% if answer is not none %}}""" f"""{{{{question+"\\n{ANSWER}"}}}}""" f"""{{% else %}}""" f"""{{{{"{QUESTION} "+question+"\\n{ANSWER}"}}}}""" f"""{{% endif %}}""", "doc_to_target": f"""{{% if answer is not none %}}""" f"""{{{{answer[{ANSWER_TO_SKIP}:]}}}}""" f"""{{% else %}}""" f"""{{{{answer_number|string}}}}""" f"""{{% endif %}}""", **filter_list, "generation_kwargs": { "until": [QUESTION, "</s>", "<|im_end|>"], "do_sample": False, }, **({"target_delimiter": DELIMITER} if DELIMITER else {}), }, f, allow_unicode=True, width=float("inf"), ) except FileExistsError: err.append(file_name) if len(err) > 0: raise FileExistsError( "Files were not created because they already exist (use --overwrite flag):" f" {', '.join(err)}" ) def main() -> None: """Parse CLI args and generate language-specific yaml files.""" parser = argparse.ArgumentParser() parser.add_argument( "--overwrite", default=False, action="store_true", help="Overwrite files if they already exist", ) parser.add_argument( "--output-dir", default=".", help="Directory to write yaml files to" ) parser.add_argument( "--mode", default="native-cot", choices=["direct", "direct-native", "native-cot", "en-cot", "translate-direct"], help="Mode of chain-of-thought", ) args = parser.parse_args() gen_lang_yamls(output_dir=args.output_dir, overwrite=args.overwrite, mode=args.mode) if __name__ == "__main__": main()