# chatlearn/runtime/evaluator.py
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluator"""

from collections import defaultdict

import ray

from chatlearn.runtime.environment import Environment
from chatlearn.utils import future
from chatlearn.utils.logger import logger
from chatlearn.utils.utils import map_reduce_metrics
from chatlearn.data.ranking import batch_generation_ranking

# pylint: disable=not-callable


class Evaluator(Environment):
"""
Evaluator.
Args
----
models : [BaseModule]
models to evaluate
args : RuntimeConfig
default to None
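
    Examples
    --------
    A minimal usage sketch (assumptions: ``eval_flow`` is a user-defined flow
    function, ``eval_prompts`` is a list of prompts, the surrounding engine
    exposes ``set_evaluator``, and dataset registration goes through the
    inherited ``set_dataset`` helper):

    >>> evaluator = Evaluator(eval_flow)
    >>> evaluator.set_dataset(eval_prompts)
    >>> engine.set_evaluator(evaluator)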
"""

    def __init__(self, model_flow):
super().__init__(model_flow)
self.is_eval = True
self._metric_prefix = "eval"
self._metric_list = []

    @property
def sample_per_episode(self):
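        """Total number of evaluation samples per episode, summed over all eval datasets."""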
return sum(len(dataset) for dataset in self._all_datasets)

    def setup_dataset(self):
assert len(self._all_datasets) > 0, "dataset is not set"
for i, dataset in enumerate(self._all_datasets):
            assert len(dataset) > 0, f"dataset {i} is empty"
if self.models[0].module_args.batch_generation.ranking:
logger.info("calling batch_generation_ranking")
for i, dataset in enumerate(self._all_datasets):
self._all_datasets[i] = batch_generation_ranking(dataset, 1, len(dataset))
refs = []
for idx, model_replica in enumerate(self.models[0].replicas):
if self.first_model.use_vllm_backend:
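                # with the vLLM backend, spread the remainder of sample_per_episode
                # across the first `remainder` replicas, one extra sample each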
remainder = self.sample_per_episode % self.models[0].num_replica
batch_size_plus = 1 if idx < remainder else 0
batch_size = self.batch_size() + batch_size_plus
else:
batch_size = self.batch_size()
if batch_size > 0:
ref = model_replica.master._build_dataloader.remote(
self._all_datasets, self.sample_per_episode, is_eval=True)
refs.append(ref)
future.get(refs)

    def get_all_merged_data_list(self, queues, encode=True):
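        """Merge one item from every queue at a time until the first queue is drained, returning the merged batches."""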
queue0 = queues[0]
merged_data_list = []
while queue0.qsize() > 0:
res = self.get_merged_data(queues, encode)
merged_data_list.append(res)
return merged_data_list

    def eval(self, cur_iter=None, train_iteration=None):
        """
        Run evaluation over all evaluation datasets.

        Args
        ----
        cur_iter : int
            current episode iteration.
        train_iteration : int
            current training iteration.
        """
refs = []
for model in self.models[0].replicas:
refs.append(model.master.reset_eval_data_iter.remote())
future.get(refs)
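        # run the evaluation flow; out_queue holds one entry per evaluated batch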
out_queue = self.execute(is_eval=True)
queue_size = out_queue.qsize()
result_refs = [out_queue.get() for _ in range(queue_size)]
element_size = len(result_refs[0])
if isinstance(result_refs[0][0], ray.ObjectRef):
data_list = future.wait(result_refs, desc="evaluator", return_output=True)
else:
data_list = result_refs[0] # List[Dict]
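        # regroup the flat result list into per-batch chunks, one item per return model node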
results = [data_list[i:i + element_size] for i in range(0, len(data_list), element_size)]
all_results = defaultdict(list)
for batches in results:
for i, batch in enumerate(batches):
model_name = self.model_flow.return_model_nodes[i].name
all_results[model_name].append(batch)
eval_info = {}
if cur_iter is not None:
eval_info["episode_iteration"] = cur_iter
if train_iteration is not None:
eval_info["train_iteration"] = train_iteration
processed_results = self.post_process(all_results, eval_info)
return processed_results

    def post_process(self, results, eval_info):  # pylint: disable=unused-argument
        """
        Default post-process function for model evaluation results.

        Args
        ----
        results : dict[str, list]
            evaluation results grouped by model name
        eval_info : dict
            metadata that may contain "train_iteration" and "episode_iteration"
        """
return results

    def get_and_clear_metrics(self):
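        """Return (metric_prefix, reduced_metrics) for metrics gathered so far, then clear the metric buffer."""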
if self._metric_list is None or len(self._metric_list) == 0:
return self._metric_prefix, {}
reduced_metrics = map_reduce_metrics(self._metric_list)
self._metric_list = []
return self._metric_prefix, reduced_metrics