# lmms_eval/api/registry.py (97 lines of code) (raw):
from lmms_eval.api.model import lmms
import logging
# Logger used by the lookup helpers in this module (get_aggregation,
# get_metric_aggregation, is_higher_better) to warn about missing entries.
eval_logger = logging.getLogger("lmms-eval")
# Key: model alias passed to register_model, Value: the decorated model class
MODEL_REGISTRY = {}
def register_model(*names):
    """Return a class decorator that registers a model class under ``names``.

    Args:
        *names: One or more aliases. Each alias becomes a key in
            ``MODEL_REGISTRY`` that maps to the decorated class.

    Returns:
        A decorator that validates the class and aliases, stores the class in
        ``MODEL_REGISTRY`` under every alias, and returns the class unchanged.

    Raises:
        AssertionError: Raised by the returned decorator when the class does
            not extend ``lmms``, or when an alias is already registered (or is
            repeated within ``names``).
    """
    # Aliases are passed as separate positional arguments and arrive here as a
    # tuple of strings.

    def decorate(cls):
        # Validate every alias before touching the registry, so that a
        # conflict on a later alias cannot leave earlier aliases registered.
        seen = set()
        for name in names:
            assert issubclass(cls, lmms), f"Model '{name}' ({cls.__name__}) must extend lmms class"
            assert name not in MODEL_REGISTRY and name not in seen, f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."
            seen.add(name)

        for name in names:
            MODEL_REGISTRY[name] = cls
        return cls

    return decorate
def get_model(model_name):
    """Return the model class registered under ``model_name``.

    Args:
        model_name: Alias that was passed to ``register_model``.

    Returns:
        The class stored in ``MODEL_REGISTRY`` for that alias.

    Raises:
        ValueError: If no model is registered under ``model_name``. The
            message lists every currently registered alias.
    """
    try:
        return MODEL_REGISTRY[model_name]
    except KeyError:
        # `from None` keeps the internal KeyError out of the traceback; the
        # ValueError message already carries the missing name.
        raise ValueError(f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY)}") from None
# Key: task name, Value: the object decorated with register_task (expected to be the task's ConfigurableTask class)
TASK_REGISTRY = {}
# Key: group name, Value: list of member names; register_group appends task names taken from func2task_index
GROUP_REGISTRY = {}
# Set of all task names and group names
ALL_TASKS = set()
# Key: `__name__` of the object decorated with register_task, Value: task name
func2task_index = {}
def register_task(name):
    """Return a decorator that registers the decorated object as task ``name``.

    The decorator stores the object in ``TASK_REGISTRY``, adds ``name`` to
    ``ALL_TASKS``, and records ``object.__name__ -> name`` in
    ``func2task_index``. It returns the object unchanged.

    Raises:
        AssertionError: Raised by the decorator if ``name`` is already a key
            of ``TASK_REGISTRY``.
    """

    def _record(task_obj):
        name_taken = name in TASK_REGISTRY
        assert not name_taken, f"task named '{name}' conflicts with existing registered task!"

        TASK_REGISTRY[name] = task_obj
        ALL_TASKS.add(name)
        # Reverse index keyed by the object's __name__, not by the object.
        func2task_index[task_obj.__name__] = name
        return task_obj

    return _record
def register_group(name):
    """Return a decorator that adds an already-registered task to group ``name``.

    The decorated object's ``__name__`` is looked up in ``func2task_index``
    to find its task name, so the object must already have been registered
    with ``register_task``. The task name is appended to
    ``GROUP_REGISTRY[name]`` (the list is created on first use) and ``name``
    is added to ``ALL_TASKS``. The object is returned unchanged.

    Raises:
        KeyError: Raised by the decorator if the object's ``__name__`` is not
            present in ``func2task_index``.
    """

    def decorate(fn):
        try:
            task_name = func2task_index[fn.__name__]
        except KeyError:
            # Same exception type as the bare lookup, but with a message that
            # says what went wrong and how to fix it.
            raise KeyError(
                f"cannot add '{fn.__name__}' to group '{name}': it has not been registered with register_task "
                "(apply @register_group above @register_task)"
            ) from None

        GROUP_REGISTRY.setdefault(name, []).append(task_name)
        # Adding to a set is idempotent, so this is safe for existing groups.
        ALL_TASKS.add(name)
        return fn

    return decorate
# Not written to by register_metric or register_aggregation.
OUTPUT_TYPE_REGISTRY = {}
# Key: metric name, Value: metric function (filled by register_metric)
METRIC_REGISTRY = {}
# Key: metric name, Value: aggregation function resolved from AGGREGATION_REGISTRY
METRIC_AGGREGATION_REGISTRY = {}
# Key: aggregation name, Value: aggregation function
AGGREGATION_REGISTRY = {}
# Key: metric name, Value: the `higher_is_better` value passed to register_metric
HIGHER_IS_BETTER_REGISTRY = {}

# Lists of metric names keyed by what look like request/output types;
# presumably the defaults used when a task names no metrics (confirm against callers).
DEFAULT_METRIC_REGISTRY = {
    "loglikelihood": [
        "perplexity",
        "acc",
    ],
    "multiple_choice": ["acc", "acc_norm"],
    "generate_until": ["exact_match"],
}


def register_metric(**args):
    """Return a decorator that registers a metric function and its metadata.

    Args:
        **args: Keyword options. ``metric`` (required) is the metric name.
            ``higher_is_better`` (optional) is stored as given in
            ``HIGHER_IS_BETTER_REGISTRY``. ``aggregation`` (optional) is the
            name of an aggregation already present in
            ``AGGREGATION_REGISTRY``; the resolved function is stored in
            ``METRIC_AGGREGATION_REGISTRY``. Any other keys are ignored.

    Returns:
        A decorator that stores the function in ``METRIC_REGISTRY`` under the
        metric name, records the optional metadata, and returns the function
        unchanged.

    Raises:
        AssertionError: Raised by the decorator if ``metric`` is missing or
            the metric name is already present in a registry it would write.
        KeyError: Raised by the decorator if ``aggregation`` names an
            aggregation that is not in ``AGGREGATION_REGISTRY``.
    """
    # TODO: do we want to enforce a certain interface to registered metrics?

    def decorate(fn):
        assert "metric" in args
        name = args["metric"]

        # Resolve and validate everything first; write only afterwards, so a
        # failure (e.g. unknown aggregation) leaves no partial registration.
        pending = []
        for key, registry in (
            ("metric", METRIC_REGISTRY),
            ("higher_is_better", HIGHER_IS_BETTER_REGISTRY),
            ("aggregation", METRIC_AGGREGATION_REGISTRY),
        ):
            if key not in args:
                continue
            value = args[key]

            # All three registries are keyed by the metric name, so the
            # conflict check must use `name`. Checking `value` would reject a
            # metric whose aggregation shares its name with a registered metric.
            if key == "metric":
                conflict_message = f"metric named '{name}' conflicts with existing registered metric!"
            else:
                conflict_message = f"{key} for metric '{name}' conflicts with existing registered {key}!"
            assert name not in registry, conflict_message

            if key == "metric":
                entry = fn
            elif key == "aggregation":
                entry = AGGREGATION_REGISTRY[value]
            else:
                entry = value
            pending.append((registry, entry))

        for registry, entry in pending:
            registry[name] = entry
        return fn

    return decorate
def register_aggregation(name):
    """Return a decorator that registers an aggregation function as ``name``.

    The decorator stores the function in ``AGGREGATION_REGISTRY`` under
    ``name`` and returns it unchanged.

    Raises:
        AssertionError: Raised by the decorator if ``name`` is already a key
            of ``AGGREGATION_REGISTRY``.
    """

    def _record(aggregation_fn):
        name_taken = name in AGGREGATION_REGISTRY
        assert not name_taken, f"aggregation named '{name}' conflicts with existing registered aggregation!"
        AGGREGATION_REGISTRY[name] = aggregation_fn
        return aggregation_fn

    return _record
def get_aggregation(name):
    """Return the aggregation function registered under ``name``.

    If ``name`` is not in ``AGGREGATION_REGISTRY``, a warning is logged and
    ``None`` is returned instead of raising.
    """
    try:
        aggregation_fn = AGGREGATION_REGISTRY[name]
    except KeyError:
        eval_logger.warning("{} not a registered aggregation metric!".format(name))
        return None
    return aggregation_fn
def get_metric_aggregation(name):
    """Return the aggregation function recorded for metric ``name``.

    If the metric has no entry in ``METRIC_AGGREGATION_REGISTRY``, a warning
    is logged and ``None`` is returned instead of raising.
    """
    try:
        aggregation_fn = METRIC_AGGREGATION_REGISTRY[name]
    except KeyError:
        eval_logger.warning("{} metric is not assigned a default aggregation!".format(name))
        return None
    return aggregation_fn
def is_higher_better(metric_name):
    """Return the ``higher_is_better`` value recorded for ``metric_name``.

    If the metric has no entry in ``HIGHER_IS_BETTER_REGISTRY``, a warning is
    logged and ``None`` is returned instead of raising.
    """
    try:
        recorded = HIGHER_IS_BETTER_REGISTRY[metric_name]
    except KeyError:
        eval_logger.warning(f"higher_is_better not specified for metric '{metric_name}'!")
        return None
    return recorded