lm_eval/tasks/afrixnli/utils.py (216 lines of code) (raw):

import argparse import yaml class FunctionTag: def __init__(self, value): self.value = value LANGUAGES = { "amh": { "QUESTION_WORD": "ትክክል", "ENTAILMENT_LABEL": "አዎ", "NEUTRAL_LABEL": "እንዲሁም", "CONTRADICTION_LABEL": "አይ", }, "eng": { "QUESTION_WORD": "Right", "ENTAILMENT_LABEL": "Yes", "NEUTRAL_LABEL": "Also", "CONTRADICTION_LABEL": "No", }, "ewe": { "QUESTION_WORD": "Esɔ gbe", "ENTAILMENT_LABEL": "Ɛ̃", "NEUTRAL_LABEL": "Hã", "CONTRADICTION_LABEL": "Ao", }, "fra": { "QUESTION_WORD": "correct", "ENTAILMENT_LABEL": "Oui", "NEUTRAL_LABEL": "Aussi", "CONTRADICTION_LABEL": "Non", }, "hau": { "QUESTION_WORD": "Daidai", "ENTAILMENT_LABEL": "Ee", "NEUTRAL_LABEL": "Haka kuma", "CONTRADICTION_LABEL": "A'a", }, "ibo": { "QUESTION_WORD": "Ziri ezi", "ENTAILMENT_LABEL": "Éè", "NEUTRAL_LABEL": "Ọzọkwa", "CONTRADICTION_LABEL": "Mba", }, "kin": { "QUESTION_WORD": "Nibyo", "ENTAILMENT_LABEL": "Yego", "NEUTRAL_LABEL": "Na none", "CONTRADICTION_LABEL": "Oya", }, "lin": { "QUESTION_WORD": "Malamu", "ENTAILMENT_LABEL": "Iyo", "NEUTRAL_LABEL": "Lisusu", "CONTRADICTION_LABEL": "Te", }, "lug": { "QUESTION_WORD": "Kituufu", "ENTAILMENT_LABEL": "Yee", "NEUTRAL_LABEL": "N’ekirala", "CONTRADICTION_LABEL": "Nedda", }, "orm": { "QUESTION_WORD": "Sirrii", "ENTAILMENT_LABEL": "Eeyyee", "NEUTRAL_LABEL": "Akkasumas", "CONTRADICTION_LABEL": "Lakki", }, "sna": { "QUESTION_WORD": "Chokwadi", "ENTAILMENT_LABEL": "Hongu", "NEUTRAL_LABEL": "Uye", "CONTRADICTION_LABEL": "Kwete", }, "sot": { "QUESTION_WORD": "Nepile", "ENTAILMENT_LABEL": "E", "NEUTRAL_LABEL": "Hape", "CONTRADICTION_LABEL": "Tjhe", }, "swa": { "QUESTION_WORD": "Sahihi", "ENTAILMENT_LABEL": "Ndiyo", "NEUTRAL_LABEL": "Pia", "CONTRADICTION_LABEL": "Hapana", }, "twi": { "QUESTION_WORD": "Nifa", "ENTAILMENT_LABEL": "Aane", "NEUTRAL_LABEL": "Anaasɛ", "CONTRADICTION_LABEL": "Daabi", }, "wol": { "QUESTION_WORD": "Dëgg", "ENTAILMENT_LABEL": "Waaw", "NEUTRAL_LABEL": "Itam", "CONTRADICTION_LABEL": "Déet", }, "xho": { "QUESTION_WORD": "Ichanekile", "ENTAILMENT_LABEL": "Ewe", "NEUTRAL_LABEL": "Kananjalo", "CONTRADICTION_LABEL": "Hayi", }, "yor": { "QUESTION_WORD": "Òótọ́", "ENTAILMENT_LABEL": "Bẹ́ẹ̀ni", "NEUTRAL_LABEL": "Àti pé", "CONTRADICTION_LABEL": "Rárá", }, "zul": { "QUESTION_WORD": "Kulungile", "ENTAILMENT_LABEL": "Yebo", "NEUTRAL_LABEL": "Futhi", "CONTRADICTION_LABEL": "Cha", }, } def gen_lang_yamls(output_dir: str, overwrite: bool, mode: str) -> None: """ Generate a yaml file for each language. :param output_dir: The directory to output the files to. :param overwrite: Whether to overwrite files if they already exist. """ err = [] languages = [ "eng", "amh", "ibo", "fra", "sna", "wol", "ewe", "lin", "lug", "xho", "kin", "twi", "zul", "orm", "yor", "hau", "sot", "swa", ] for lang in languages: try: if mode == "native-direct": QUESTION_WORD = LANGUAGES[lang]["QUESTION_WORD"] ENTAILMENT_LABEL = LANGUAGES[lang]["ENTAILMENT_LABEL"] NEUTRAL_LABEL = LANGUAGES[lang]["NEUTRAL_LABEL"] CONTRADICTION_LABEL = LANGUAGES[lang]["CONTRADICTION_LABEL"] file_name = f"afrixnli_native_direct_{lang}.yaml" task_name = f"afrixnli_native_direct_{lang}" yaml_template = "afrixnli_native_direct_yaml" with open( f"{output_dir}/{file_name}", "w" if overwrite else "x", encoding="utf8", ) as f: f.write("# Generated by utils.py\n") yaml.dump( { "include": yaml_template, "task": task_name, "dataset_name": lang, "doc_to_choice": f"{{{{[" f"""premise+\", {QUESTION_WORD}? {ENTAILMENT_LABEL}, \"+hypothesis,""" f"""premise+\", {QUESTION_WORD}? {NEUTRAL_LABEL}, \"+hypothesis,""" f"""premise+\", {QUESTION_WORD}? {CONTRADICTION_LABEL}, \"+hypothesis""" f"]}}}}", }, f, allow_unicode=True, ) else: file_name = f"afrixnli_{mode}_{lang}.yaml" task_name = f"afrixnli_{mode}_{lang}" yaml_template = f"afrixnli_{mode}_yaml" with open( f"{output_dir}/{file_name}", "w" if overwrite else "x", encoding="utf8", ) as f: f.write("# Generated by utils.py\n") yaml.dump( { "include": yaml_template, "task": task_name, "dataset_name": lang, }, f, allow_unicode=True, ) except FileExistsError: err.append(file_name) if len(err) > 0: raise FileExistsError( "Files were not created because they already exist (use --overwrite flag):" f" {', '.join(err)}" ) def main() -> None: """Parse CLI args and generate language-specific yaml files.""" parser = argparse.ArgumentParser() parser.add_argument( "--overwrite", default=True, action="store_true", help="Overwrite files if they already exist", ) parser.add_argument( "--output-dir", default="./manual/translate", help="Directory to write yaml files to", ) parser.add_argument( "--mode", default="manual_translate", choices=["en_direct", "native-direct", "manual_direct", "manual_translate"], help="Mode of chain-of-thought", ) args = parser.parse_args() gen_lang_yamls(output_dir=args.output_dir, overwrite=args.overwrite, mode=args.mode) if __name__ == "__main__": main()