def _prepare_metric_and_aggregation()

in lmms_eval/api/task.py [0:0]


    def _prepare_metric_and_aggregation(self):
        """Build the per-metric lookup tables from the task configuration.

        Resets and fills four dictionaries on ``self``, all keyed by metric
        name:

        * ``_metric_fn_list``   -- the metric function (``None`` when
          ``config.process_results`` is set),
        * ``_metric_fn_kwargs`` -- extra keyword arguments for that function,
        * ``_aggregation_list`` -- the aggregation callable,
        * ``_higher_is_better`` -- the ``higher_is_better`` flag.

        When ``config.metric_list`` is ``None`` the metric names come from
        ``DEFAULT_METRIC_REGISTRY[config.output_type]`` and every table is
        filled from the registries. Otherwise each entry of
        ``config.metric_list`` must contain a ``"metric"`` key (a registry
        name or a callable) and may contain ``"aggregation"`` (a registry name
        or a callable) and ``"higher_is_better"``; every other key is passed
        through as a keyword argument of the metric function. Missing
        ``"aggregation"`` / ``"higher_is_better"`` values fall back to the
        registry defaults and a warning is logged.

        Raises:
            AssertionError: if an entry of ``config.metric_list`` has no
                ``"metric"`` key.
        """
        self._metric_fn_list = {}
        self._metric_fn_kwargs = {}
        self._aggregation_list = {}
        self._higher_is_better = {}

        if self.config.metric_list is None:
            # TODO: handle this in TaskConfig.__post_init__ ?
            for metric_name in DEFAULT_METRIC_REGISTRY[self.config.output_type]:
                self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
                self._metric_fn_kwargs[metric_name] = {}
                self._aggregation_list[metric_name] = get_metric_aggregation(metric_name)
                self._higher_is_better[metric_name] = is_higher_better(metric_name)
            return

        # Keys of a metric entry that configure this method rather than the
        # metric function itself.
        reserved_keys = ("metric", "aggregation", "higher_is_better")
        # Reverse lookup (aggregation callable -> registry name), only needed
        # for the "using default aggregation" warning; built on first use.
        inverse_agg_registry = None

        for metric_config in self.config.metric_list:
            assert "metric" in metric_config, "every metric_list entry needs a 'metric' key"
            metric_name = metric_config["metric"]
            kwargs = {key: metric_config[key] for key in metric_config if key not in reserved_keys}

            if self.config.process_results is not None:
                # The task computes its own results, so no metric function or
                # kwargs are stored for this metric.
                self._metric_fn_list[metric_name] = None
                self._metric_fn_kwargs[metric_name] = {}
            elif callable(metric_name):
                # A callable metric is registered under its ``__name__``.
                metric_fn = metric_name.__call__
                metric_name = metric_name.__name__
                self._metric_fn_list[metric_name] = metric_fn
                self._metric_fn_kwargs[metric_name] = kwargs
            else:
                self._metric_fn_list[metric_name] = METRIC_REGISTRY[metric_name]
                self._metric_fn_kwargs[metric_name] = kwargs

            if "aggregation" in metric_config:
                agg_name = metric_config["aggregation"]
                if isinstance(agg_name, str):
                    self._aggregation_list[metric_name] = get_aggregation(agg_name)
                elif callable(agg_name):
                    self._aggregation_list[metric_name] = agg_name
                else:
                    # Previously this case was dropped without a trace; keep
                    # the "nothing registered" outcome but make it visible.
                    eval_logger.warning(
                        f"[Task: {self._config.task}] metric {metric_name} has an aggregation of unsupported type "
                        f"{type(agg_name).__name__} (expected a registered name or a callable); "
                        f"no aggregation was registered for it"
                    )
            else:
                if inverse_agg_registry is None:
                    inverse_agg_registry = {v: k for k, v in AGGREGATION_REGISTRY.items()}
                metric_agg = get_metric_aggregation(metric_name)
                eval_logger.warning(f"[Task: {self._config.task}] metric {metric_name} is defined, but aggregation is not. " f"using default " f"aggregation={inverse_agg_registry[metric_agg]}")
                self._aggregation_list[metric_name] = metric_agg

            if "higher_is_better" in metric_config:
                self._higher_is_better[metric_name] = metric_config["higher_is_better"]
            else:
                default_higher_is_better = is_higher_better(metric_name)
                eval_logger.warning(f"[Task: {self._config.task}] metric {metric_name} is defined, but higher_is_better is not. " f"using default " f"higher_is_better={default_higher_is_better}")
                self._higher_is_better[metric_name] = default_higher_is_better