in lmms_eval/api/task.py [0:0]
def _prepare_metric_and_aggregation(self):
self._metric_fn_list = {}
self._metric_fn_kwargs = {}
self._aggregation_list = {}
self._higher_is_better = {}
if self.config.metric_list is None:
# TODO: handle this in TaskConfig.__post_init__ ?
_metric_list = DEFAULT_METRIC_REGISTRY[self.config.output_type]
for metric_name in _metric_list:
self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
self._metric_fn_kwargs[metric_name] = {}
self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
self._higher_is_better[metric_name] = is_higher_better(metric_name)
else:
for metric_config in self.config.metric_list:
assert "metric" in metric_config
metric_name = metric_config["metric"]
kwargs = {key: metric_config[key] for key in metric_config if key not in ["metric", "aggregation", "higher_is_better"]}
if self.config.process_results is not None:
self._metric_fn_list[metric_name] = None
self._metric_fn_kwargs[metric_name] = {}
elif callable(metric_name):
metric_fn = metric_name.__call__
metric_name = metric_name.__name__
self._metric_fn_list[metric_name] = metric_fn
self._metric_fn_kwargs[metric_name] = kwargs
else:
self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
self._metric_fn_kwargs[metric_name] = kwargs
if "aggregation" in metric_config:
agg_name = metric_config["aggregation"]
if type(agg_name) == str:
self._aggregation_list[metric_name] = get_aggregation(agg_name)
elif callable(agg_name):
self._aggregation_list[metric_name] = metric_config["aggregation"]
else:
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
metric_agg = get_metric_aggregation(metric_name)
eval_logger.warning(f"[Task: {self._config.task}] metric {metric_name} is defined, but aggregation is not. " f"using default " f"aggregation={INV_AGG_REGISTRY[metric_agg]}")
self._aggregation_list[metric_name] = metric_agg
if "higher_is_better" in metric_config:
self._higher_is_better[metric_name] = metric_config["higher_is_better"]
else:
eval_logger.warning(f"[Task: {self._config.task}] metric {metric_name} is defined, but higher_is_better is not. " f"using default " f"higher_is_better={is_higher_better(metric_name)}")
self._higher_is_better[metric_name] = is_higher_better(metric_name)