lmms_eval/tasks/__init__.py (107 lines of code) (raw):
import os
from typing import List, Union, Dict
from lmms_eval import utils
# from lmms_eval import prompts
from lmms_eval.api.task import TaskConfig, Task, ConfigurableTask
from lmms_eval.api.registry import (
register_task,
register_group,
TASK_REGISTRY,
GROUP_REGISTRY,
ALL_TASKS,
)
import logging
eval_logger = logging.getLogger("lmms-eval")
def register_configurable_task(config: Dict[str, str]) -> int:
SubClass = type(
config["task"] + "ConfigurableTask",
(ConfigurableTask,),
{"CONFIG": TaskConfig(**config)},
)
if "task" in config:
task_name = "{}".format(config["task"])
register_task(task_name)(SubClass)
if "group" in config:
if config["group"] == config["task"]:
raise ValueError("task and group name cannot be the same")
elif type(config["group"]) == str:
group_name = [config["group"]]
else:
group_name = config["group"]
for group in group_name:
register_group(group)(SubClass)
return 0
def register_configurable_group(config: Dict[str, str]) -> int:
group = config["group"]
task_list = config["task"]
task_names = utils.pattern_match(task_list, ALL_TASKS)
for task in task_names:
if (task in TASK_REGISTRY) or (task in GROUP_REGISTRY):
if group in GROUP_REGISTRY:
GROUP_REGISTRY[group].append(task)
else:
GROUP_REGISTRY[group] = [task]
ALL_TASKS.add(group)
return 0
def get_task_name_from_config(task_config: Dict[str, str]) -> str:
if "dataset_name" in task_config:
return "{dataset_path}_{dataset_name}".format(**task_config)
else:
return "{dataset_path}".format(**task_config)
def include_task_folder(task_dir: str, register_task: bool = True) -> None:
"""
Calling this function
"""
for root, subdirs, file_list in os.walk(task_dir):
# if (subdirs == [] or subdirs == ["__pycache__"]) and (len(file_list) > 0):
for f in file_list:
if f.endswith(".yaml"):
yaml_path = os.path.join(root, f)
try:
config = utils.load_yaml_config(yaml_path)
if "task" not in config:
continue
if register_task:
if type(config["task"]) == str:
register_configurable_task(config)
else:
if type(config["task"]) == list:
register_configurable_group(config)
# Log this silently and show it only when
# the user defines the appropriate verbosity.
except ModuleNotFoundError as e:
eval_logger.debug(f"{yaml_path}: {e}. Config will not be added to registry.")
except Exception as error:
import traceback
eval_logger.debug(f"Failed to load config in {yaml_path}. Config will not be added to registry\n" f"Error: {error}\n" f"Traceback: {traceback.format_exc()}")
return 0
def include_path(task_dir):
include_task_folder(task_dir)
# Register Benchmarks after all tasks have been added
include_task_folder(task_dir, register_task=False)
return 0
def initialize_tasks(verbosity="INFO"):
eval_logger.setLevel(getattr(logging, f"{verbosity}"))
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
include_path(task_dir)
def get_task(task_name, model_name):
try:
return TASK_REGISTRY[task_name](model_name=model_name)
except KeyError:
eval_logger.info("Available tasks:")
eval_logger.info(list(TASK_REGISTRY) + list(GROUP_REGISTRY))
raise KeyError(f"Missing task {task_name}")
def get_task_name_from_object(task_object):
for name, class_ in TASK_REGISTRY.items():
if class_ is task_object:
return name
# TODO: scrap this
# this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
return task_object.EVAL_HARNESS_NAME if hasattr(task_object, "EVAL_HARNESS_NAME") else type(task_object).__name__
# TODO: pass num_fewshot and other cmdline overrides in a better way
def get_task_dict(task_name_list: List[Union[str, Dict, Task]], model_name: str):
all_task_dict = {}
# Ensure task_name_list is a list to simplify processing
if not isinstance(task_name_list, list):
task_name_list = [task_name_list]
for task_element in task_name_list:
if isinstance(task_element, str) and task_element in GROUP_REGISTRY:
group_name = task_element
for task_name in GROUP_REGISTRY[task_element]:
if task_name not in all_task_dict:
# Recursively get the task dictionary for nested groups
task_obj = get_task_dict([task_name], model_name)
# Merge the dictionaries
all_task_dict.update({task_name: (group_name, task_obj.get(task_name, None))})
else:
task_name = task_element if isinstance(task_element, str) else task_element.EVAL_HARNESS_NAME
if task_name not in all_task_dict:
task_obj = get_task(task_name=task_name, model_name=model_name)
all_task_dict[task_name] = task_obj
return all_task_dict