# chatlearn/runtime/engine.py
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Engine"""
import os
import shutil
import time
import torch
from chatlearn.checkpoint.checkpoint_manager import CheckpointManager
from chatlearn.data.data import StreamDataset
from chatlearn.models.base_module import BaseModule
from chatlearn.runtime.dist_actor import DistVLLMActor
from chatlearn.runtime.environment import Environment
from chatlearn.runtime.evaluator import Evaluator
from chatlearn.runtime.trainer import Trainer
from chatlearn.schedule.model_manager import ModelManager
from chatlearn.schedule.resource_manager import ResourceManager
from chatlearn.schedule.metric_manager import MetricManager
from chatlearn.utils import future
from chatlearn.utils.constant import LOG_START
from chatlearn.utils.global_vars import get_args
from chatlearn.utils.logger import logger
from chatlearn.utils.timer import Timers
from chatlearn.utils.utils import get_full_proc_memory_info, map_reduce_metrics
class BaseEngine:
"""Base Engine"""
def __init__(self, *models):
self._models = models
self.global_args = get_args()
self.runtime_args = self.global_args.runtime_args
self._timers = Timers()
self.writer_dict = {}
def set_timers(self, _timers):
self._timers = _timers
@property
def timers(self):
return self._timers
def timer_summary(self):
"""
:meta private:
"""
if self._timers:
return self._timers.log(reset=False, return_dict=True)
def _create_remote_models(self):
resource_manager = ResourceManager(self._models)
self.model_manager = ModelManager(self._models, resource_manager, self.global_args)
self.model_manager.remote()
self.remote_models = self.model_manager.dist_models
self.named_models = {model.name: model for model in self.remote_models}
def _create_metric_manager(self):
self.metric_manager = MetricManager(self.global_args)
def setup(self):
"""
:meta private:
"""
logger.info(f"{LOG_START} setup, start to create_remote_models")
self._create_metric_manager()
t1 = time.time()
self._create_remote_models()
t2 = time.time()
logger.info(f"{LOG_START} setup, finished to create_remote_models(s):{(t2-t1)}")
        # for ease of access to each model via self.{model_name}
for model in self.remote_models:
setattr(self, model.name, model)
        # model init includes compilation; compile dependencies must be invoked serially
logger.info(get_full_proc_memory_info(f"{LOG_START} Before model init"))
for model in self.remote_models:
model.init()
logger.info(get_full_proc_memory_info(f"{LOG_START} After model init"))
        # do not include compile dependencies in setup
        # if the program hangs in setup, try setting concurrent_setup to False.
self.timers("setup_models").start()
if self.runtime_args.concurrent_setup:
refs = []
refs_val = []
for model in self.remote_models:
refs += model.model_setup()
refs_val += model.validate()
future.wait(refs)
future.wait(refs_val)
else:
for model in self.remote_models:
logger.info(f"{LOG_START} start setup and validate {model.name}")
future.wait(model.model_setup())
future.wait(model.validate())
logger.info(f"{LOG_START} done setup and validate {model.name}")
self.timers("setup_models").stop()
logger.info(
f"{LOG_START} setup_models summary {self.timers.log(names=['setup_models'])}")
def before_episode(self):
for model in self.remote_models:
future.get(model.before_episode())
def after_episode(self):
for model in self.remote_models:
future.get(model.after_episode())
@property
def models(self):
return self.remote_models
def get_model(self, name):
return self.named_models[name]
def logging_memory(self):
def flatten(xs):
for x in xs:
if isinstance(x, list):
yield from flatten(x)
else:
yield x
refs = []
for model in self.remote_models:
mem_ref = model.peak_memory()
refs.append(mem_ref)
summaries = future.get(refs)
logger.debug(f"{LOG_START} memory summary:")
for model, summary in zip(self.remote_models, summaries):
mem_str = ' | '.join(['{:.2f}'.format(i) for i in flatten(summary)])
mem_log = f"peak_mem(GiB): {mem_str}"
logger.debug(f"{LOG_START} {model.name} {mem_log}")
def logging_summary(self, iteration=-1):
_, e2e_time_dict = self.timer_summary()
refs = []
for model in self.remote_models:
time_ref = model.replicas[0].timer_summary(e2e_cost=e2e_time_dict.get(model.name, None))
refs.append(time_ref)
summaries = future.get(refs)
for key, value in e2e_time_dict.items():
e2e_time_dict[key] = {'e2e': value}
logger.info(f"{LOG_START} episode iteration {iteration + 1} time summary for each model as follows:")
for model, summary in zip(self.remote_models, summaries):
summary_str, summary_dict = summary[-1] if isinstance(summary, list) else summary
logger.info(f"{LOG_START} [{model.name}] {summary_str}")
if model.name not in e2e_time_dict:
e2e_time_dict[model.name] = {}
e2e_time_dict[model.name].update(summary_dict)
self.logging_memory()
return e2e_time_dict
def stop(self):
self.metric_manager.stop()
self.model_manager.clean()
class Engine(BaseEngine):
"""Engine"""
def __init__(self, environment=None, trainer=None, evaluator=None, name='alignment'):
"""
Engine.
Args
----
        environment : Environment
            the environment executor
        trainer : Trainer
            the trainer executor
        evaluator : Evaluator
            the evaluator executor
        name : str
            the engine name
"""
models = []
for executor in [environment, trainer, evaluator]:
if executor:
for model in executor.models:
if model not in models:
models.append(model)
super().__init__(*models)
if environment:
environment.set_timers(self.timers)
if trainer:
trainer.set_timers(self.timers)
self.env = environment
self.trainer = trainer
self.evaluator = evaluator
self._start_episode = 0
self._all_datasets = None
self._post_process = None
self._drop_last = False
self._wrap_data = True
self._relay_sample_manager = None
self._data_loader = None
self._param_sync_pairs = []
self._name = name
def set_parameter_sync(self, src_model, dst_model):
"""
        Sync model parameters from src_model to dst_model.

        Args
        ----
        src_model : BaseModule
            the source model whose parameters are synced to dst_model
        dst_model : BaseModule
            the destination model that receives the synced parameters
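
        Example (illustrative sketch; ``environment``, ``trainer``, ``policy_trainer``
        and ``policy`` are assumed to be user-constructed objects)::

            engine = Engine(environment, trainer)
            engine.set_parameter_sync(policy_trainer, policy)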
"""
self._param_sync_pairs.append((src_model, dst_model))
return self
def _create_remote_models(self):
"""
:meta private:
"""
logger.info(f"{LOG_START} create_remote_models, start to create resource_manager")
t1 = time.time()
resource_manager = ResourceManager(self._models)
t2 = time.time()
logger.info(f"{LOG_START} create_remote_models, finished to create resource_manager(s):{(t2-t1)}")
self.model_manager = ModelManager(self._models, resource_manager, self.global_args)
for src_model, dst_model in self._param_sync_pairs:
self.model_manager.set_parameter_sync(src_model, dst_model)
self.model_manager.remote()
t3 = time.time()
logger.info(f"{LOG_START} create_remote_models, finished to set_parameter_sync(s):{(t3-t2)}")
self.remote_models = self.model_manager.dist_models
self.named_models = {model.name: model for model in self.remote_models}
t4 = time.time()
logger.info(f"{LOG_START} create_remote_models, finished to get named_models(s):{(t4-t3)}")
def setup(self):
"""
:meta private:
"""
super().setup()
self._executors = [self.env, self.trainer, self.evaluator]
for executor in self._executors:
if executor:
executor.update_models(self.remote_models)
if self.env:
self.env.set_multiple_datasets(self._all_datasets)
self.timers("build_sync_paramter_groups").start()
self.model_manager.build_parameter_group()
self.timers("build_sync_paramter_groups").stop()
logger.info(
f"{LOG_START} {self._name} build_sync_paramter_groups summary {self.timers.log(names=['build_sync_paramter_groups'])}")
self.model_manager.start_error_monitor()
def set_dataset(self, dataset):
"""
Set prompt dataset.
Args
----
dataset : list[str]
            a list of prompt strings
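
        Example (illustrative; the prompt strings are placeholders)::

            engine.set_dataset(["prompt 1", "prompt 2", "prompt 3"])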
"""
assert isinstance(dataset, list), (
f"expect datasets to be a list, got {type(dataset)}"
)
assert not isinstance(dataset[0], list), (
"expect only one dataset to be set, if you want to use more "
"than one dataset, please try `set_multiple_datasets`"
)
self._all_datasets = [dataset]
return self
def set_multiple_datasets(self, all_datasets):
"""
Set multiple prompt datasets.
Args
----
all_datasets : list[list[str]]
            a list of prompt datasets, each a list of prompt strings
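
        Example (illustrative; the prompt strings are placeholders)::

            engine.set_multiple_datasets([["math prompt 1"],
                                          ["code prompt 1", "code prompt 2"]])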
"""
# sanity check
assert len(all_datasets) >= 1, (
f"expect at least one dataset, got {len(all_datasets)} datasets."
)
assert isinstance(all_datasets, list), (
f"expect datasets to be a list, got {type(all_datasets)}"
)
for dataset in all_datasets:
assert isinstance(dataset, list), (
f"expect each dataset to be a list of prompts, got {type(dataset)}"
)
self._all_datasets = all_datasets
return self
def set_trainer(self, trainer):
self.trainer = trainer
return self
def set_environment(self, env):
self.env = env
return self
def set_evaluator(self, evaluator):
self.evaluator = evaluator
return self
def logging_summary(self, iteration=-1):
"""
:meta private:
"""
## 1. model e2e time
e2e_time_dict = super().logging_summary(iteration)
# flatten time to name/<e2e or forward_step or eval_step and so on>
model_time_dict = {}
for model in self.remote_models:
model_e2e_time_dict = e2e_time_dict.get(model.name, {})
for key, value in model_e2e_time_dict.items():
model_time_dict[f"{model.name}/{key}"] = value
## 2. episode time
timer_names = ['sync_parameters',]
        # timer names used before the episode loop starts
if iteration == -1 and self.evaluator and self.runtime_args.enable_eval_before_training:
timer_names.append('evaluate')
        # timer names used within the episode loop
elif iteration >= 0:
            timer_names.extend(['episode', 'train'])
if self.runtime_args.save_episode_interval and \
(iteration + 1) % self.runtime_args.save_episode_interval == 0:
timer_names.append('save_checkpoint')
if self.evaluator is not None and \
self.runtime_args.eval_episode_interval and \
(iteration + 1) % self.runtime_args.eval_episode_interval == 0:
timer_names.append('evaluate')
episode_str, episode_metrics = self.timers.log(names=timer_names, return_dict=True)
log_str = f"{LOG_START} {self._name} episode summary, episode {iteration + 1} {episode_str}"
logger.info(log_str)
## 3. log model e2e time and episode time
episode_metrics.update(model_time_dict)
self.metric_manager.log("engine/timer_summary", iteration + 1, episode_metrics)
## 4. log before episode looping
if iteration == -1:
if self.evaluator and self.runtime_args.enable_eval_before_training:
prefix, evaluate_metrics = self.evaluator.get_and_clear_metrics()
self.metric_manager.log(prefix, iteration + 1, evaluate_metrics)
return
## 5. log in episode looping
# Train metrics
for model in self.remote_models:
            # all_metric_tuples collects metrics from the last rank of each replica
            # (rank n-1, rank 2n-1, ...); each entry is a (prefix, metrics) tuple.
# example1 [[('vllm_inference', {'prompt_token_length': 108.5})], [('vllm_inference', {'prompt_token_length': 121.75})]]
# example2 [('', {})]
# example3 [('', {'train_reward_score': 0.78125}), ('', {'train_reward_score': 0.625})]
all_metric_tuples = future.get(model.get_and_clear_metrics())
if isinstance(all_metric_tuples[0], list):
                all_metric_tuples_flattened = []
                for item in all_metric_tuples:
                    all_metric_tuples_flattened += item
                all_metric_tuples = all_metric_tuples_flattened
prefix = all_metric_tuples[0][0]
last_rank_metrics = [metric_tuple[1] for metric_tuple in all_metric_tuples]
model_metrics = map_reduce_metrics(last_rank_metrics)
self.metric_manager.log(prefix, iteration + 1, model_metrics)
# Reward metrics
if self._data_loader:
prefix, train_reward_metrics = future.get(self._data_loader.get_and_clear_metrics.remote())
self.metric_manager.log(prefix, iteration + 1, train_reward_metrics)
# Evaluate metrics
if self.evaluator:
prefix, evaluate_metrics = self.evaluator.get_and_clear_metrics()
self.metric_manager.log(prefix, iteration + 1, evaluate_metrics)
def set_relay_sample_manager(self, relay_sample_manager):
"""
Set custom relay_sample_manager.
Args
----
        relay_sample_manager : Callable
            a callable that takes a List[EpisodeRelayBuffer] as input and returns a list of dict.
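
        Example (illustrative sketch; assumes each relay buffer exposes its collected
        sample dicts via a ``buffer`` attribute, which may differ from the actual
        EpisodeRelayBuffer API)::

            def relay_sample_manager(relay_buffers):
                samples = []
                for relay_buffer in relay_buffers:
                    samples.extend(relay_buffer.buffer)
                return samples

            engine.set_relay_sample_manager(relay_sample_manager)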
"""
self._relay_sample_manager = relay_sample_manager
def learn(self):
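        """
        Run the training loop: set up executors, build the stream data loader,
        synchronize parameters, then for each episode make experiences, train,
        and optionally save checkpoints and evaluate.
        """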
self.timers("chatlearn").start()
self.timers("setup").start()
self.setup()
self.timers("executor_setup").start()
for executor in self._executors:
if executor:
executor.setup()
self.timers("executor_setup").stop()
logger.info(
f"{LOG_START} {self._name} setup executors: {self.timers.log(names=['executor_setup'])}")
self.timers("setup").stop()
logger.info(
f"{LOG_START} {self._name} setup summary {self.timers.log(names=['setup'])}")
self.logging_memory()
self._resume_from_data_checkpoint()
# Enable chunkflow optimization
        enable_chunkflow_optimization = os.environ.get("ENABLE_CHUNKFLOW_OPTIMIZATION", "False") in ("True", "true", "1")
logger.info(f"{LOG_START} Check ENABLE_CHUNKFLOW_OPTIMIZATION={enable_chunkflow_optimization} for chunkflow optimization")
data_loader = StreamDataset.remote(
self.runtime_args.stream_data_loader_type,
self.runtime_args.train_micro_batch_size,
self.env._padding_config,
self.runtime_args.max_relay_episode,
self.runtime_args.relay_episode_offset,
self.runtime_args.train_global_batch_size \
if enable_chunkflow_optimization \
else self.runtime_args.train_micro_batch_size
)
logger.info(f"{LOG_START} " + get_full_proc_memory_info('Before first param sync'))
dump_root_path = os.getenv("DEBUG_SYNC_PARAMETERS_PATH", "")
if dump_root_path:
if os.path.exists(dump_root_path):
shutil.rmtree(dump_root_path)
logger.info(f"{LOG_START} dump parameters before syncnizing...")
self.dump_parameters(os.path.join(dump_root_path, "before_sync_parameter"))
self.timers("sync_parameters").start()
if os.getenv("ENABLE_PARAM_SYNC_WARMUP", "false") == "true":
self.timers("warmup_sync_parameters").start()
self.model_manager.sync_parameters(requires_grad=False, validate=False, dryrun=True)
self.model_manager.warmup_collective_topology()
self.timers("warmup_sync_parameters").stop()
logger.info(f"{LOG_START} finish warmup_sync_parameters {self.timers.log(names=['warmup_sync_parameters'])} ")
self.model_manager.sync_parameters(requires_grad=False, validate=self.runtime_args.validate_param_sync)
self.timers("sync_parameters").stop()
if self.runtime_args.enable_eval_before_training:
self.evaluate(-1)
if dump_root_path:
logger.info(f"{LOG_START} dump parameters after synchronizing...")
self.dump_parameters(os.path.join(dump_root_path, "after_sync_parameter"))
logger.info(f"{LOG_START} finish dump parameters, ChatLearn will exit")
return
logger.info(get_full_proc_memory_info('After first param sync'))
self.logging_summary(-1)
self._data_loader = data_loader
for episode_id in range(self._start_episode, self.runtime_args.num_episode):
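            # With nsys profiling enabled, capture a single episode by starting the
            # CUDA profiler at episode_id 4 and stopping it at episode_id 5.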
if self.runtime_args.nsys:
if episode_id == 4:
torch.cuda.cudart().cudaProfilerStart()
if episode_id == 5:
torch.cuda.cudart().cudaProfilerStop()
self.timers("episode").start()
self.before_episode()
logger.info(f"{LOG_START} start train episode_id: {episode_id + 1}/{self.runtime_args.num_episode}")
if self.env.timers is None:
self.env.set_timers(self.timers)
queue = []
if os.getenv("SKIP_GENERATION", None) is None:
logger.info(f"{LOG_START} start to make experience: {episode_id + 1}/{self.runtime_args.num_episode}")
queue = self.env.make_experiences()
logger.info(f"{LOG_START} complete to make experience: {episode_id + 1}/{self.runtime_args.num_episode}")
self.timers("set_train_dataset").start()
else:
logger.info(f"{LOG_START} Skip generation phase for episode_id: {episode_id + 1}/{self.runtime_args.num_episode}")
refs = data_loader.set_dataset.remote(queue, episode_id, self._relay_sample_manager,
self.runtime_args.sample_per_episode)
future.wait(refs, return_output=True)
if self.trainer is not None:
# validate parameter sync in the first two episodes
validate = self.runtime_args.validate_param_sync and episode_id < 2
self.timers("set_train_dataset").stop()
self.trainer.set_data_loader(data_loader)
logger.info(f"{LOG_START} set dataloader for trainer done")
logger.info(get_full_proc_memory_info(f"{LOG_START} Before train {episode_id}"))
if self.trainer.timers is None:
self.trainer.set_timers(self.timers)
self.trainer.train(episode_id)
logger.info(get_full_proc_memory_info(f"{LOG_START} After train {episode_id}"))
self.timers("sync_parameters").start()
self.model_manager.sync_parameters(episode_id + 1, validate=validate)
self.timers("sync_parameters").stop()
logger.info(f"{LOG_START} train episode_id: {episode_id + 1}/{self.runtime_args.num_episode} parameter sync done")
logger.info(f"{LOG_START} train episode_id: {episode_id + 1}/{self.runtime_args.num_episode} done")
self.timers("episode").stop()
self.save_checkpoint(episode_id)
self.evaluate(episode_id)
self.after_episode()
self.logging_summary(episode_id)
self.timers("chatlearn").stop()
logger.info(f"{LOG_START} {self._name} overall summary {self.timers.log(names=['chatlearn'])}")
logger.info(f"{LOG_START} train {self._name} done")
def _resume_from_data_checkpoint(self):
if self.runtime_args.data_checkpoint_path:
data_ckpt_manager = CheckpointManager(self.models[0].replicas[0], self.runtime_args.data_checkpoint_path,
self.runtime_args.max_data_ckpt_nums,
self.runtime_args.load_data_checkpoint_iteration)
if self.runtime_args.enable_resume_training:
meta = data_ckpt_manager.resume_meta()
if meta:
self._start_episode = meta["episode"] + 1
self.trainer.iteration = meta["train_iteration"]
if self.trainer.iteration > 0:
logger.info(f"{LOG_START} continue train with meta {meta}")
def dump_parameters(self, dump_path):
for _, model in enumerate(self.models):
            replica_0 = model.replicas[0]
            if isinstance(replica_0, DistVLLMActor):
                future.wait(replica_0.vllm_engine.dump_parameters.remote(dump_path))
def save_checkpoint(self, episode_id):
"""
:meta private:
"""
if self.runtime_args.save_episode_interval and \
(episode_id + 1) % self.runtime_args.save_episode_interval == 0:
self.timers("save_checkpoint").start()
for model in self.trainer.models:
refs = model.replicas[0].onload(to_onload_optimizer_states=False)
future.wait(refs)
refs = model.replicas[0].save_checkpoint(self.trainer.iteration)
future.wait(refs)
refs = model.replicas[0].offload()
future.wait(refs)
refs = []
for i, model in enumerate(self.models[0].replicas):
if isinstance(model, DistVLLMActor):
refs.append(model.vllm_engine.save_data_checkpoint.remote(i, self.trainer.iteration, episode_id))
else:
refs.append(model.all_actors[0].save_data_checkpoint.remote(i, self.trainer.iteration, episode_id))
future.get(refs)
self.timers("save_checkpoint").stop()
logger.info(f"{LOG_START} save checkpoint episode {episode_id}, train iteration {self.trainer.iteration} done")
def evaluate(self, episode_id):
"""
:meta private:
"""
if self.evaluator is not None and \
self.runtime_args.eval_episode_interval and \
(episode_id + 1) % self.runtime_args.eval_episode_interval == 0:
if self.evaluator.timers is None:
self.evaluator.set_timers(self.timers)
logger.info(f"{LOG_START} start evaluate")
self.timers("evaluate").start()
self.evaluator.eval(episode_id, self.trainer.iteration)
self.timers("evaluate").stop()
logger.info(f"{LOG_START} evaluate done")
class RLHFEngine(Engine):
"""RLHFEngine"""
def __init__(self,
policy: BaseModule,
reference: BaseModule,
reward: BaseModule,
value: BaseModule,
policy_trainer: BaseModule,
value_trainer: BaseModule):
def env_compute_flow(batch):
policy_out = policy.forward_step(batch)
ref_out = reference.forward_step(policy_out)
value_out = value.forward_step(policy_out)
reward_out = reward.forward_step(policy_out, ref_out, value_out)
return value_out, reward_out
def trainer_compute_flow(batch):
policy_trainer.train_step(batch)
value_trainer.train_step(batch)
env = Environment(env_compute_flow)
trainer = Trainer(trainer_compute_flow)
super().__init__(env, trainer, name='rlhf')
self.set_parameter_sync(policy_trainer, policy)
self.set_parameter_sync(value_trainer, value)
class OnlineDPOEngine(Engine):
"""Online DPO Engine."""
def __init__(self,
policy: BaseModule,
reference: BaseModule,
reward: BaseModule,
policy_trainer: BaseModule):
def env_compute_flow(batch):
policy_out = policy.forward_step(batch)
ref_out = reference.forward_step(policy_out)
reward_out = reward.forward_step(policy_out, ref_out)
return reward_out
def trainer_compute_flow(batch):
policy_trainer.train_step(batch)
env = Environment(env_compute_flow)
trainer = Trainer(trainer_compute_flow)
super().__init__(env, trainer, name='online_dpo')
self.set_parameter_sync(policy_trainer, policy)
class DPOEngine(Engine):
"""DPO Engine."""
def __init__(self,
reference: BaseModule,
policy_trainer: BaseModule):
def env_compute_flow(batch):
ref_out = reference.forward_step(batch)
return ref_out
def trainer_compute_flow(batch):
policy_trainer.train_step(batch)
env = Environment(env_compute_flow)
trainer = Trainer(trainer_compute_flow)
super().__init__(env, trainer, name='dpo')
class GRPOEngine(Engine):
"""GRPO Engine."""
def __init__(self,
policy: BaseModule,
reference: BaseModule,
reward: BaseModule,
policy_trainer: BaseModule):
def env_compute_flow(batch):
policy_out = policy.forward_step(batch)
ref_out = reference.forward_step(policy_out)
reward_out = reward.forward_step(policy_out, ref_out)
return reward_out
def trainer_compute_flow(batch):
policy_trainer.train_step(batch)
env = Environment(env_compute_flow)
trainer = Trainer(trainer_compute_flow)
super().__init__(env, trainer, name='grpo')
self.set_parameter_sync(policy_trainer, policy)
class GRPOMathEngine(Engine):
"""GRPO Engine with math reward"""
def __init__(self,
policy,
reference,
reward,
reward1,
ppo_policy):
def env_compute_flow(batch):
policy_out = policy.forward_step(batch)
ref_out = reference.forward_step(policy_out)
reward_out = reward.forward_step(policy_out, ref_out)
reward_out1 = reward1.forward_step(batch, policy_out)
return reward_out, reward_out1
def trainer_compute_flow(batch):
ppo_policy.train_step(batch)
def evaluator_flow(batch):
policy_out = policy.eval_forward(batch)
reward_out = reward.eval_forward(policy_out)
reward_out1 = reward1.eval_forward(policy_out)
return reward_out, reward_out1
env = Environment(env_compute_flow)
trainer = Trainer(trainer_compute_flow)
evaluator = Evaluator(evaluator_flow)
super().__init__(env, trainer, evaluator, name='grpo_math')
self.set_parameter_sync(ppo_policy, policy)
class EvalEngine(Engine):
"""Evaluation Engine"""
def __init__(self, eval_flow=None, evaluator=None):
if evaluator is None:
evaluator = Evaluator(eval_flow)
super().__init__(evaluator=evaluator)
def setup(self):
super().setup()
self.evaluator.set_multiple_datasets(self._all_datasets)
self.evaluator.set_timers(self.timers)
def set_dataset(self, dataset):
"""
Set prompt dataset.
Args
----
dataset : list[str]
            a list of prompt strings
"""
assert isinstance(dataset, list), (
f"expect datasets to be a list, got {type(dataset)}"
)
assert not isinstance(dataset[0], list), (
"expect only one dataset to be set, if you want to use more "
"than one dataset, please try `set_multiple_datasets`"
)
self._all_datasets = [dataset]
return self
def set_multiple_datasets(self, all_datasets):
"""
Set multiple prompt datasets.
Args
----
all_datasets : list[list[str]]
            a list of prompt datasets, each a list of prompt strings
"""
# sanity check
assert len(all_datasets) >= 1, (
f"expect at least one dataset, got {len(all_datasets)} datasets."
)
assert isinstance(all_datasets, list), (
f"expect datasets to be a list, got {type(all_datasets)}"
)
for dataset in all_datasets:
assert isinstance(dataset, list), (
f"expect each dataset to be a list of prompts, got {type(dataset)}"
)
self._all_datasets = all_datasets
return self
def eval(self, cur_iter=None, train_iteration=None):
"""
Start evaluating.
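
        Example (illustrative; ``eval_flow`` and ``prompts`` are user-provided)::

            engine = EvalEngine(eval_flow)
            engine.set_dataset(prompts)
            results = engine.eval()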
"""
self.setup()
self.evaluator.setup()
self.timers("episode").start()
results = self.evaluator.eval(
cur_iter=cur_iter, train_iteration=train_iteration)
self.timers("episode").stop()
return results