# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Engine"""

import os
import shutil
import time
import torch

from chatlearn.checkpoint.checkpoint_manager import CheckpointManager
from chatlearn.data.data import StreamDataset
from chatlearn.models.base_module import BaseModule
from chatlearn.runtime.dist_actor import DistVLLMActor
from chatlearn.runtime.environment import Environment
from chatlearn.runtime.evaluator import Evaluator
from chatlearn.runtime.trainer import Trainer
from chatlearn.schedule.model_manager import ModelManager
from chatlearn.schedule.resource_manager import ResourceManager
from chatlearn.schedule.metric_manager import MetricManager
from chatlearn.utils import future
from chatlearn.utils.constant import LOG_START
from chatlearn.utils.global_vars import get_args
from chatlearn.utils.logger import logger
from chatlearn.utils.utils import get_full_proc_memory_info
from chatlearn.utils.timer import Timers
from chatlearn.utils.utils import map_reduce_metrics


class BaseEngine:
    """Base Engine"""

    def __init__(self, *models):
        self._models = models
        self.global_args = get_args()
        self.runtime_args = self.global_args.runtime_args
        self._timers = Timers()
        self.writer_dict = {}

    def set_timers(self, _timers):
        self._timers = _timers

    @property
    def timers(self):
        return self._timers

    def timer_summary(self):
        """
        :meta private:
        """
        if self._timers:
            return self._timers.log(reset=False, return_dict=True)

    def _create_remote_models(self):
        resource_manager = ResourceManager(self._models)
        self.model_manager = ModelManager(self._models, resource_manager, self.global_args)
        self.model_manager.remote()
        self.remote_models = self.model_manager.dist_models
        self.named_models = {model.name: model for model in self.remote_models}

    def _create_metric_manager(self):
        self.metric_manager = MetricManager(self.global_args)

    def setup(self):
        """
        :meta private:
        """
        logger.info(f"{LOG_START} setup, start to create_remote_models")
        self._create_metric_manager()
        t1 = time.time()
        self._create_remote_models()
        t2 = time.time()
        logger.info(f"{LOG_START} setup, finished to create_remote_models(s):{(t2-t1)}")
        # for ease to access model by self.{model_name}
        for model in self.remote_models:
            setattr(self, model.name, model)

        # include compilation in init; compile dependencies need to be called serially
        logger.info(get_full_proc_memory_info(f"{LOG_START} Before model init"))
        for model in self.remote_models:
            model.init()
        logger.info(get_full_proc_memory_info(f"{LOG_START} After model init"))
        # do not include compile dependencies in setup
        # if the program hangs in setup, try setting concurrent_setup to False.
        self.timers("setup_models").start()
        if self.runtime_args.concurrent_setup:
            refs = []
            refs_val = []
            for model in self.remote_models:
                refs += model.model_setup()
                refs_val += model.validate()
            future.wait(refs)
            future.wait(refs_val)
        else:
            for model in self.remote_models:
                logger.info(f"{LOG_START} start setup and validate {model.name}")
                future.wait(model.model_setup())
                future.wait(model.validate())
                logger.info(f"{LOG_START} done setup and validate {model.name}")
        self.timers("setup_models").stop()
        logger.info(
            f"{LOG_START} setup_models summary {self.timers.log(names=['setup_models'])}")

    def before_episode(self):
        for model in self.remote_models:
            future.get(model.before_episode())

    def after_episode(self):
        for model in self.remote_models:
            future.get(model.after_episode())

    @property
    def models(self):
        return self.remote_models

    def get_model(self, name):
        return self.named_models[name]

    def logging_memory(self):
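        """
        Collect peak memory stats from each remote model and log them at debug level.
        """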
        def flatten(xs):
            for x in xs:
                if isinstance(x, list):
                    yield from flatten(x)
                else:
                    yield x

        refs = []
        for model in self.remote_models:
            mem_ref = model.peak_memory()
            refs.append(mem_ref)
        summaries = future.get(refs)

        logger.debug(f"{LOG_START} memory summary:")
        for model, summary in zip(self.remote_models, summaries):
            mem_str = ' | '.join(['{:.2f}'.format(i) for i in flatten(summary)])
            mem_log = f"peak_mem(GiB): {mem_str}"
            logger.debug(f"{LOG_START} {model.name} {mem_log}")

    def logging_summary(self, iteration=-1):
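        """
        Log per-model timer summaries and peak memory for the given episode
        iteration and return the aggregated e2e time dict.
        """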
        _, e2e_time_dict = self.timer_summary()
        refs = []
        for model in self.remote_models:
            time_ref = model.replicas[0].timer_summary(e2e_cost=e2e_time_dict.get(model.name, None))
            refs.append(time_ref)
        summaries = future.get(refs)
        for key, value in e2e_time_dict.items():
            e2e_time_dict[key] = {'e2e': value}

        logger.info(f"{LOG_START} episode iteration {iteration + 1} time summary for each model as follows:")
        for model, summary in zip(self.remote_models, summaries):
            summary_str, summary_dict = summary[-1] if isinstance(summary, list) else summary
            logger.info(f"{LOG_START} [{model.name}] {summary_str}")
            if model.name not in e2e_time_dict:
                e2e_time_dict[model.name] = {}
            e2e_time_dict[model.name].update(summary_dict)

        self.logging_memory()
        return e2e_time_dict

    def stop(self):
        self.metric_manager.stop()
        self.model_manager.clean()


class Engine(BaseEngine):
    """Engine"""

    def __init__(self, environment=None, trainer=None, evaluator=None, name='alignment'):
        """
        Engine.

        Args
        ----
        environment : Environment
        trainer : Trainer
        evaluator : Evaluator
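
        Example
        -------
        A minimal sketch, assuming `env`, `trainer` and `evaluator` are
        pre-built executors and `prompts` is a list of prompt strings::

            engine = Engine(env, trainer, evaluator)
            engine.set_dataset(prompts)
            engine.learn()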
        """
        models = []
        for executor in [environment, trainer, evaluator]:
            if executor:
                for model in executor.models:
                    if model not in models:
                        models.append(model)
        super().__init__(*models)
        if environment:
            environment.set_timers(self.timers)
        if trainer:
            trainer.set_timers(self.timers)
        self.env = environment
        self.trainer = trainer
        self.evaluator = evaluator
        self._start_episode = 0
        self._all_datasets = None
        self._post_process = None
        self._drop_last = False
        self._wrap_data = True
        self._relay_sample_manager = None
        self._data_loader = None
        self._param_sync_pairs = []
        self._name = name

    def set_parameter_sync(self, src_model, dst_model):
        """
        Sync model parameters from src_model to dst_model.

        Args
        ----
        src_model : BaseModule
            source model to sync parameters from
        dst_model : BaseModule
            destination model to sync parameters to
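
        Example
        -------
        A minimal sketch, assuming `policy_trainer` and `policy` are models
        managed by this engine::

            engine.set_parameter_sync(policy_trainer, policy)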
        """
        self._param_sync_pairs.append((src_model, dst_model))
        return self

    def _create_remote_models(self):
        """
        :meta private:
        """
        logger.info(f"{LOG_START} create_remote_models, start to create resource_manager")
        t1 = time.time()
        resource_manager = ResourceManager(self._models)
        t2 = time.time()
        logger.info(f"{LOG_START} create_remote_models, finished to create resource_manager(s):{(t2-t1)}")
        self.model_manager = ModelManager(self._models, resource_manager, self.global_args)
        for src_model, dst_model in self._param_sync_pairs:
            self.model_manager.set_parameter_sync(src_model, dst_model)
        self.model_manager.remote()
        t3 = time.time()
        logger.info(f"{LOG_START} create_remote_models, finished to set_parameter_sync(s):{(t3-t2)}")
        self.remote_models = self.model_manager.dist_models
        self.named_models = {model.name: model for model in self.remote_models}
        t4 = time.time()
        logger.info(f"{LOG_START} create_remote_models, finished to get named_models(s):{(t4-t3)}")

    def setup(self):
        """
        :meta private:
        """
        super().setup()
        self._executors = [self.env, self.trainer, self.evaluator]
        for executor in self._executors:
            if executor:
                executor.update_models(self.remote_models)
        if self.env:
            self.env.set_multiple_datasets(self._all_datasets)
        self.timers("build_sync_paramter_groups").start()
        self.model_manager.build_parameter_group()
        self.timers("build_sync_paramter_groups").stop()
        logger.info(
            f"{LOG_START} {self._name} build_sync_paramter_groups summary {self.timers.log(names=['build_sync_paramter_groups'])}")
        self.model_manager.start_error_monitor()

    def set_dataset(self, dataset):
        """
        Set prompt dataset.

        Args
        ----
        dataset : list[str]
            a list of prompt strings
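
        Example
        -------
        A minimal sketch with an inline prompt list::

            engine.set_dataset(["prompt one", "prompt two"])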
        """
        assert isinstance(dataset, list), (
            f"expect datasets to be a list, got {type(dataset)}"
        )
        assert not isinstance(dataset[0], list), (
            "expect only one dataset to be set, if you want to use more "
            "than one dataset, please try `set_multiple_datasets`"
        )
        self._all_datasets = [dataset]
        return self

    def set_multiple_datasets(self, all_datasets):
        """
        Set multiple prompt datasets.

        Args
        ----
        all_datasets : list[list[str]]
            a list of lists of prompt strings
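
        Example
        -------
        A minimal sketch with two inline prompt datasets::

            engine.set_multiple_datasets([
                ["prompt a1", "prompt a2"],
                ["prompt b1"],
            ])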
        """
        # sanity check
        assert isinstance(all_datasets, list), (
            f"expect datasets to be a list, got {type(all_datasets)}"
        )
        assert len(all_datasets) >= 1, (
            f"expect at least one dataset, got {len(all_datasets)} datasets."
        )
        for dataset in all_datasets:
            assert isinstance(dataset, list), (
                f"expect each dataset to be a list of prompts, got {type(dataset)}"
            )

        self._all_datasets = all_datasets
        return self

    def set_trainer(self, trainer):
        self.trainer = trainer
        return self

    def set_environment(self, env):
        self.env = env
        return self

    def set_evaluator(self, evaluator):
        self.evaluator = evaluator
        return self

    def logging_summary(self, iteration=-1):
        """
        :meta private:
        """
        ## 1. model e2e time
        e2e_time_dict = super().logging_summary(iteration)
        # flatten times into "<model_name>/<e2e | forward_step | eval_step | ...>"
        model_time_dict = {}
        for model in self.remote_models:
            model_e2e_time_dict = e2e_time_dict.get(model.name, {})
            for key, value in model_e2e_time_dict.items():
                model_time_dict[f"{model.name}/{key}"] = value

        ## 2. episode time
        timer_names = ['sync_parameters',]
        # timer_names before episode looping
        if iteration == -1 and self.evaluator and self.runtime_args.enable_eval_before_training:
            timer_names.append('evaluate')
        # timer_names in episode looping
        elif iteration >= 0:
            timer_names.extend(['episode','train',])
            if self.runtime_args.save_episode_interval and \
                    (iteration + 1) % self.runtime_args.save_episode_interval == 0:
                timer_names.append('save_checkpoint')
            if self.evaluator is not None and \
                    self.runtime_args.eval_episode_interval and \
                    (iteration + 1) % self.runtime_args.eval_episode_interval == 0:
                timer_names.append('evaluate')

        episode_str, episode_metrics = self.timers.log(names=timer_names, return_dict=True)

        log_str = f"{LOG_START} {self._name} episode summary, episode {iteration + 1} {episode_str}"
        logger.info(log_str)

        ## 3. log model e2e time and episode time
        episode_metrics.update(model_time_dict)
        self.metric_manager.log("engine/timer_summary", iteration + 1, episode_metrics)

        ## 4. log before episode looping
        if iteration == -1:
            if self.evaluator and self.runtime_args.enable_eval_before_training:
                prefix, evaluate_metrics = self.evaluator.get_and_clear_metrics()
                self.metric_manager.log(prefix, iteration + 1, evaluate_metrics)
            return

        ## 5. log in episode looping
        # Train metrics
        for model in self.remote_models:
            # all_metric_tuples holds the metrics of the last rank of each replica,
            # i.e. [rank n-1, rank 2n-1, ...]; each entry is a (prefix, metric) tuple
            # example 1: [[('vllm_inference', {'prompt_token_length': 108.5})], [('vllm_inference', {'prompt_token_length': 121.75})]]
            # example 2: [('', {})]
            # example 3: [('', {'train_reward_score': 0.78125}), ('', {'train_reward_score': 0.625})]

            all_metric_tuples = future.get(model.get_and_clear_metrics())
            if isinstance(all_metric_tuples[0], list):
                all_metric_tuples_flattened = []
                for item in all_metric_tuples:
                    all_metric_tuples_flattened += item
                all_metric_tuples = all_metric_tuples_flattened
            prefix = all_metric_tuples[0][0]
            last_rank_metrics = [metric_tuple[1] for metric_tuple in all_metric_tuples]
            model_metrics = map_reduce_metrics(last_rank_metrics)
            self.metric_manager.log(prefix, iteration + 1, model_metrics)
        # Reward metrics
        if self._data_loader:
            prefix, train_reward_metrics = future.get(self._data_loader.get_and_clear_metrics.remote())
            self.metric_manager.log(prefix, iteration + 1, train_reward_metrics)
        # Evaluate metrics
        if self.evaluator:
            prefix, evaluate_metrics = self.evaluator.get_and_clear_metrics()
            self.metric_manager.log(prefix, iteration + 1, evaluate_metrics)


    def set_relay_sample_manager(self, relay_sample_manager):
        """
        Set custom relay_sample_manager.

        Args
        ----
        relay_sample_manager : callable
            takes a List[EpisodeRelayBuffer] as input and returns a list of dict
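
        Example
        -------
        A minimal sketch; `my_relay_sample_fn` is a hypothetical callable::

            def my_relay_sample_fn(relay_buffers):
                # relay_buffers: List[EpisodeRelayBuffer]
                # build and return a list of dict samples
                ...

            engine.set_relay_sample_manager(my_relay_sample_fn)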
        """
        self._relay_sample_manager = relay_sample_manager

    def learn(self):
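        """
        Run the training loop: set up all executors, perform the initial
        parameter sync, then for each episode make experiences, train,
        sync parameters, and optionally save checkpoints and evaluate.
        """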
        self.timers("chatlearn").start()
        self.timers("setup").start()
        self.setup()
        self.timers("executor_setup").start()
        for executor in self._executors:
            if executor:
                executor.setup()
        self.timers("executor_setup").stop()
        logger.info(
            f"{LOG_START} {self._name} setup executors: {self.timers.log(names=['executor_setup'])}")
        self.timers("setup").stop()
        logger.info(
            f"{LOG_START} {self._name} setup summary {self.timers.log(names=['setup'])}")
        self.logging_memory()
        self._resume_from_data_checkpoint()
        # Enable chunkflow optimization
        enable_chunkflow_optimization = os.environ.get("ENABLE_CHUNKFLOW_OPTIMIZATION", "False") in ["True", "true", "1"]
        logger.info(f"{LOG_START} Check ENABLE_CHUNKFLOW_OPTIMIZATION={enable_chunkflow_optimization} for chunkflow optimization")
        data_loader = StreamDataset.remote(
            self.runtime_args.stream_data_loader_type,
            self.runtime_args.train_micro_batch_size,
            self.env._padding_config,
            self.runtime_args.max_relay_episode,
            self.runtime_args.relay_episode_offset,
            self.runtime_args.train_global_batch_size \
                if enable_chunkflow_optimization \
                else self.runtime_args.train_micro_batch_size
        )

        logger.info(f"{LOG_START} " + get_full_proc_memory_info('Before first param sync'))
        dump_root_path = os.getenv("DEBUG_SYNC_PARAMETERS_PATH", "")
        if dump_root_path:
            if os.path.exists(dump_root_path):
                shutil.rmtree(dump_root_path)
            logger.info(f"{LOG_START} dump parameters before syncnizing...")
            self.dump_parameters(os.path.join(dump_root_path, "before_sync_parameter"))
        self.timers("sync_parameters").start()
        if os.getenv("ENABLE_PARAM_SYNC_WARMUP", "false") == "true":
            self.timers("warmup_sync_parameters").start()
            self.model_manager.sync_parameters(requires_grad=False, validate=False, dryrun=True)
            self.model_manager.warmup_collective_topology()
            self.timers("warmup_sync_parameters").stop()
            logger.info(f"{LOG_START} finish warmup_sync_parameters {self.timers.log(names=['warmup_sync_parameters'])} ")
        self.model_manager.sync_parameters(requires_grad=False, validate=self.runtime_args.validate_param_sync)
        self.timers("sync_parameters").stop()
        if self.runtime_args.enable_eval_before_training:
            self.evaluate(-1)
        if dump_root_path:
            logger.info(f"{LOG_START} dump parameters after synchronizing...")
            self.dump_parameters(os.path.join(dump_root_path, "after_sync_parameter"))
            logger.info(f"{LOG_START} finish dump parameters, ChatLearn will exit")
            return
        logger.info(get_full_proc_memory_info('After first param sync'))
        self.logging_summary(-1)

        self._data_loader = data_loader
        for episode_id in range(self._start_episode, self.runtime_args.num_episode):
            if self.runtime_args.nsys:
                if episode_id == 4:
                    torch.cuda.cudart().cudaProfilerStart()
                if episode_id == 5:
                    torch.cuda.cudart().cudaProfilerStop()
            self.timers("episode").start()
            self.before_episode()
            logger.info(f"{LOG_START} start train episode_id: {episode_id + 1}/{self.runtime_args.num_episode}")
            if self.env.timers is None:
                self.env.set_timers(self.timers)
            queue = []
            if os.getenv("SKIP_GENERATION", None) is None:
                logger.info(f"{LOG_START} start to make experience: {episode_id + 1}/{self.runtime_args.num_episode}")
                queue = self.env.make_experiences()
                logger.info(f"{LOG_START} complete to make experience: {episode_id + 1}/{self.runtime_args.num_episode}")
                self.timers("set_train_dataset").start()
            else:
                logger.info(f"{LOG_START} Skip generation phase for episode_id: {episode_id + 1}/{self.runtime_args.num_episode}")
            refs = data_loader.set_dataset.remote(queue, episode_id, self._relay_sample_manager,
                                                  self.runtime_args.sample_per_episode)
            future.wait(refs, return_output=True)
            if self.trainer is not None:
                # validate parameter sync in the first two episodes
                validate = self.runtime_args.validate_param_sync and episode_id < 2
                self.timers("set_train_dataset").stop()
                self.trainer.set_data_loader(data_loader)
                logger.info(f"{LOG_START} set dataloader for trainer done")
                logger.info(get_full_proc_memory_info(f"{LOG_START} Before train {episode_id}"))
                if self.trainer.timers is None:
                    self.trainer.set_timers(self.timers)
                self.trainer.train(episode_id)
                logger.info(get_full_proc_memory_info(f"{LOG_START} After train {episode_id}"))
                self.timers("sync_parameters").start()
                self.model_manager.sync_parameters(episode_id + 1, validate=validate)
                self.timers("sync_parameters").stop()
                logger.info(f"{LOG_START} train episode_id: {episode_id + 1}/{self.runtime_args.num_episode} parameter sync done")
            logger.info(f"{LOG_START} train episode_id: {episode_id + 1}/{self.runtime_args.num_episode} done")
            self.timers("episode").stop()
            self.save_checkpoint(episode_id)
            self.evaluate(episode_id)
            self.after_episode()
            self.logging_summary(episode_id)

        self.timers("chatlearn").stop()
        logger.info(f"{LOG_START} {self._name} overall summary {self.timers.log(names=['chatlearn'])}")
        logger.info(f"{LOG_START} train {self._name} done")

    def _resume_from_data_checkpoint(self):
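        """
        Restore the start episode and trainer iteration from the data
        checkpoint when resume training is enabled.
        """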
        if self.runtime_args.data_checkpoint_path:
            data_ckpt_manager = CheckpointManager(self.models[0].replicas[0], self.runtime_args.data_checkpoint_path,
                                                  self.runtime_args.max_data_ckpt_nums,
                                                  self.runtime_args.load_data_checkpoint_iteration)
            if self.runtime_args.enable_resume_training:
                meta = data_ckpt_manager.resume_meta()
                if meta:
                    self._start_episode = meta["episode"] + 1
                    self.trainer.iteration = meta["train_iteration"]
                    if self.trainer.iteration > 0:
                        logger.info(f"{LOG_START} continue train with meta {meta}")

    def dump_parameters(self, dump_path):
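        """
        Dump parameters of each vLLM-based model replica to dump_path (debug utility).
        """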
        for model in self.models:
            replica_0 = model.replicas[0]
            if isinstance(replica_0, DistVLLMActor):
                future.wait(replica_0.vllm_engine.dump_parameters.remote(dump_path))

    def save_checkpoint(self, episode_id):
        """
        :meta private:
        """
        if self.runtime_args.save_episode_interval and \
                (episode_id + 1) % self.runtime_args.save_episode_interval == 0:
            self.timers("save_checkpoint").start()
            for model in self.trainer.models:
                refs = model.replicas[0].onload(to_onload_optimizer_states=False)
                future.wait(refs)
                refs = model.replicas[0].save_checkpoint(self.trainer.iteration)
                future.wait(refs)
                refs = model.replicas[0].offload()
                future.wait(refs)
            refs = []
            for i, model in enumerate(self.models[0].replicas):
                if isinstance(model, DistVLLMActor):
                    refs.append(model.vllm_engine.save_data_checkpoint.remote(i, self.trainer.iteration, episode_id))
                else:
                    refs.append(model.all_actors[0].save_data_checkpoint.remote(i, self.trainer.iteration, episode_id))
            future.get(refs)
            self.timers("save_checkpoint").stop()
            logger.info(f"{LOG_START} save checkpoint episode {episode_id}, train iteration {self.trainer.iteration} done")

    def evaluate(self, episode_id):
        """
        :meta private:
        """
        if self.evaluator is not None and \
                self.runtime_args.eval_episode_interval and \
                (episode_id + 1) % self.runtime_args.eval_episode_interval == 0:
            if self.evaluator.timers is None:
                self.evaluator.set_timers(self.timers)
            logger.info(f"{LOG_START} start evaluate")
            self.timers("evaluate").start()
            self.evaluator.eval(episode_id, self.trainer.iteration)
            self.timers("evaluate").stop()
            logger.info(f"{LOG_START} evaluate done")


class RLHFEngine(Engine):
    """RLHFEngine"""

    def __init__(self,
                 policy: BaseModule,
                 reference: BaseModule,
                 reward: BaseModule,
                 value: BaseModule,
                 policy_trainer: BaseModule,
                 value_trainer: BaseModule):
        def env_compute_flow(batch):
            policy_out = policy.forward_step(batch)
            ref_out = reference.forward_step(policy_out)
            value_out = value.forward_step(policy_out)
            reward_out = reward.forward_step(policy_out, ref_out, value_out)
            return value_out, reward_out

        def trainer_compute_flow(batch):
            policy_trainer.train_step(batch)
            value_trainer.train_step(batch)

        env = Environment(env_compute_flow)
        trainer = Trainer(trainer_compute_flow)
        super().__init__(env, trainer, name='rlhf')
        self.set_parameter_sync(policy_trainer, policy)
        self.set_parameter_sync(value_trainer, value)


class OnlineDPOEngine(Engine):
    """Online DPO Engine."""
    def __init__(self,
                 policy: BaseModule,
                 reference: BaseModule,
                 reward: BaseModule,
                 policy_trainer: BaseModule):
        def env_compute_flow(batch):
            policy_out = policy.forward_step(batch)
            ref_out = reference.forward_step(policy_out)
            reward_out = reward.forward_step(policy_out, ref_out)
            return reward_out

        def trainer_compute_flow(batch):
            policy_trainer.train_step(batch)

        env = Environment(env_compute_flow)
        trainer = Trainer(trainer_compute_flow)
        super().__init__(env, trainer, name='online_dpo')
        self.set_parameter_sync(policy_trainer, policy)


class DPOEngine(Engine):
    """DPO Engine."""
    def __init__(self,
                 reference: BaseModule,
                 policy_trainer: BaseModule):
        def env_compute_flow(batch):
            ref_out = reference.forward_step(batch)
            return ref_out

        def trainer_compute_flow(batch):
            policy_trainer.train_step(batch)

        env = Environment(env_compute_flow)
        trainer = Trainer(trainer_compute_flow)
        super().__init__(env, trainer, name='dpo')


class GRPOEngine(Engine):
    """GRPO Engine."""
    def __init__(self,
                 policy: BaseModule,
                 reference: BaseModule,
                 reward: BaseModule,
                 policy_trainer: BaseModule):
        def env_compute_flow(batch):
            policy_out = policy.forward_step(batch)
            ref_out = reference.forward_step(policy_out)
            reward_out = reward.forward_step(policy_out, ref_out)
            return reward_out

        def trainer_compute_flow(batch):
            policy_trainer.train_step(batch)

        env = Environment(env_compute_flow)
        trainer = Trainer(trainer_compute_flow)
        super().__init__(env, trainer, name='grpo')
        self.set_parameter_sync(policy_trainer, policy)


class GRPOMathEngine(Engine):
    """GRPO Engine with math reward"""
    def __init__(self,
                 policy,
                 reference,
                 reward,
                 reward1,
                 ppo_policy):

        def env_compute_flow(batch):
            policy_out = policy.forward_step(batch)
            ref_out = reference.forward_step(policy_out)
            reward_out = reward.forward_step(policy_out, ref_out)
            reward_out1 = reward1.forward_step(batch, policy_out)
            return reward_out, reward_out1

        def trainer_compute_flow(batch):
            ppo_policy.train_step(batch)

        def evaluator_flow(batch):
            policy_out = policy.eval_forward(batch)
            reward_out = reward.eval_forward(policy_out)
            reward_out1 = reward1.eval_forward(policy_out)
            return reward_out, reward_out1

        env = Environment(env_compute_flow)
        trainer = Trainer(trainer_compute_flow)
        evaluator = Evaluator(evaluator_flow)
        super().__init__(env, trainer, evaluator, name='grpo_math')
        self.set_parameter_sync(ppo_policy, policy)


class EvalEngine(Engine):
    """Evaluation Engine"""

    def __init__(self, eval_flow=None, evaluator=None):
        if evaluator is None:
            evaluator = Evaluator(eval_flow)
        super().__init__(evaluator=evaluator)

    def setup(self):
        super().setup()
        self.evaluator.set_multiple_datasets(self._all_datasets)
        self.evaluator.set_timers(self.timers)

    def set_dataset(self, dataset):
        """
        Set prompt dataset.

        Args
        ----
        dataset : list[str]
            a list of prompt strings
        """
        assert isinstance(dataset, list), (
            f"expect datasets to be a list, got {type(dataset)}"
        )
        assert not isinstance(dataset[0], list), (
            "expect only one dataset to be set, if you want to use more "
            "than one dataset, please try `set_multiple_datasets`"
        )
        self._all_datasets = [dataset]
        return self

    def set_multiple_datasets(self, all_datasets):
        """
        Set multiple prompt datasets.

        Args
        ----
        all_datasets : list[list[str]]
            a list of lists of prompt strings
        """
        # sanity check
        assert isinstance(all_datasets, list), (
            f"expect datasets to be a list, got {type(all_datasets)}"
        )
        assert len(all_datasets) >= 1, (
            f"expect at least one dataset, got {len(all_datasets)} datasets."
        )
        for dataset in all_datasets:
            assert isinstance(dataset, list), (
                f"expect each dataset to be a list of prompts, got {type(dataset)}"
            )

        self._all_datasets = all_datasets
        return self

    def eval(self, cur_iter=None, train_iteration=None):
        """
        Start evaluating.
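
        Example
        -------
        A minimal sketch, assuming `eval_flow` is an evaluation flow function
        and `prompts` is a list of prompt strings::

            engine = EvalEngine(eval_flow)
            engine.set_dataset(prompts)
            results = engine.eval()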
        """
        self.setup()
        self.evaluator.setup()
        self.timers("episode").start()
        results = self.evaluator.eval(
            cur_iter=cur_iter, train_iteration=train_iteration)
        self.timers("episode").stop()
        return results
