lm_eval/tasks/belebele/_generate_configs.py (75 lines of code) (raw):

""" Take in a YAML, and output all other splits with this YAML """ import argparse import os import requests import yaml from tqdm import tqdm from lm_eval.utils import logging API_URL = "https://datasets-server.huggingface.co/splits?dataset=facebook/belebele" def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--base_yaml_path", required=True) parser.add_argument("--save_prefix_path", default="belebele") parser.add_argument("--cot_prompt_path", default=None) parser.add_argument("--task_prefix", default="") return parser.parse_args() if __name__ == "__main__": args = parse_args() # get filename of base_yaml so we can `"include": ` it in our other YAMLs. base_yaml_name = os.path.split(args.base_yaml_path)[-1] with open(args.base_yaml_path, encoding="utf-8") as f: base_yaml = yaml.full_load(f) if args.cot_prompt_path is not None: import json with open(args.cot_prompt_path, encoding="utf-8") as f: cot_file = json.load(f) def query(): response = requests.get(API_URL) return response.json()["splits"] print(query()) languages = [split["split"] for split in query()] for lang in tqdm([lang for lang in languages if "default" not in lang]): yaml_dict = { "include": base_yaml_name, "task": f"belebele_{args.task_prefix}_{lang}" if args.task_prefix != "" else f"belebele_{lang}", "test_split": lang, "fewshot_split": lang, } file_save_path = args.save_prefix_path + f"_{lang}.yaml" logging.info(f"Saving yaml for subset {lang} to {file_save_path}") with open(file_save_path, "w", encoding="utf-8") as yaml_file: yaml.dump( yaml_dict, yaml_file, width=float("inf"), allow_unicode=True, default_style='"', ) # write group config out group_yaml_dict = { "group": f"belebele_{args.task_prefix}" if args.task_prefix != "" else "belebele", "task": [ ( f"belebele_{args.task_prefix}_{lang}" if args.task_prefix != "" else f"belebele_{lang}" ) for lang in languages if "default" not in lang ], "aggregate_metric_list": [ {"metric": "acc", "aggregation": "mean", "weight_by_size": False}, {"metric": "acc_norm", "aggregation": "mean", "weight_by_size": False}, ], "metadata": {"version": 0.0}, } file_save_path = "_" + args.save_prefix_path + f"{args.task_prefix}.yaml" with open(file_save_path, "w", encoding="utf-8") as group_yaml_file: yaml.dump( group_yaml_dict, group_yaml_file, width=float("inf"), allow_unicode=True, default_style='"', )