def get_task_dict()

in lmms_eval/tasks/__init__.py [0:0]


def get_task_dict(task_name_list: List[Union[str, Dict, Task]], model_name: str):
    """Resolve group names, task names and task objects into a task dictionary.

    Args:
        task_name_list: A single element or a list of elements. A string found
            in ``GROUP_REGISTRY`` is expanded as a group. Any other string is
            treated as a task name. Any other object is looked up by its
            ``EVAL_HARNESS_NAME`` attribute.
        model_name: Forwarded to ``get_task`` for every task that is loaded.

    Returns:
        A dict keyed by task name. A task requested directly maps to the
        object returned by ``get_task``. A task reached through a group maps
        to a ``(group_name, task_obj)`` tuple, where ``group_name`` is the
        group that lists the task directly. For nested groups, that is the
        innermost group. The first occurrence of a task name wins; later
        duplicates are ignored.
    """
    all_task_dict = {}

    # Ensure task_name_list is a list to simplify processing
    if not isinstance(task_name_list, list):
        task_name_list = [task_name_list]

    for task_element in task_name_list:
        if isinstance(task_element, str) and task_element in GROUP_REGISTRY:
            group_name = task_element
            for task_name in GROUP_REGISTRY[group_name]:
                if task_name in all_task_dict:
                    continue
                if isinstance(task_name, str) and task_name in GROUP_REGISTRY:
                    # Nested group: the recursive call returns entries keyed by
                    # the nested group's own tasks, already tagged with that
                    # group's name. Merge them instead of looking up
                    # ``task_name``, which is not a key of the result. That
                    # lookup would always yield None and drop the subtasks.
                    nested_task_dict = get_task_dict([task_name], model_name)
                    for sub_name, sub_entry in nested_task_dict.items():
                        all_task_dict.setdefault(sub_name, sub_entry)
                else:
                    task_obj = get_task_dict([task_name], model_name).get(task_name, None)
                    all_task_dict[task_name] = (group_name, task_obj)
        else:
            task_name = task_element if isinstance(task_element, str) else task_element.EVAL_HARNESS_NAME
            if task_name not in all_task_dict:
                task_obj = get_task(task_name=task_name, model_name=model_name)
                all_task_dict[task_name] = task_obj

    return all_task_dict