# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""base module"""

from collections import defaultdict
from itertools import cycle
from pathlib import Path
import math
import time
import os

import torch

import ray
import ray.util.collective as col
from ray.util.collective.collective_group.base_collective_group import BaseGroup
from ray.util.collective.collective_group.nccl_collective_group import NCCLGroup
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from chatlearn.data.sampler import MultiDatasetSampler
from chatlearn.data.data import RLHFDataLoader
from chatlearn.checkpoint.checkpoint_manager import CheckpointManager
from chatlearn.utils import future
from chatlearn.utils.constant import LOG_START
from chatlearn.utils.dist_utils import bucket_tensors, coalesced_comm_dense
from chatlearn.utils.dist_utils import bucket_tensors_two_stage_generator, coalesced_comm_dense_two_stage
from chatlearn.utils.global_vars import get_args
from chatlearn.utils.global_vars import set_global_variables
from chatlearn.utils.logger import logger
from chatlearn.utils.logger import log_rank_0, debug_rank_0, setup_logger
from chatlearn.utils.timer import Timers
from chatlearn.utils.utils import get_host_addr, map_reduce_metrics
from chatlearn.launcher import dlc_utils


class BaseModule:
    """BaseModule is the base class for Base models.

    Args
    ----
    name : str
        model name
    args : optional
        global arguments; if `None`, the globally registered arguments are used
    replica_id : int
        index of the model replica this instance belongs to (default: 0)
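
    Example
    -------
    A minimal sketch of a custom module. The subclass name, ``build_my_model``
    and the data keys are illustrative; concrete models usually derive from a
    framework-specific subclass of BaseModule::

        class MyPolicy(BaseModule):

            def setup(self):
                # create the model (and optimizer for trainable modules)
                self.model = build_my_model(self.model_args)

            def forward_step(self, data, iteration):
                # return a dict whose values have batch size as the first dim
                return {"logits": self.model(data["input_ids"])}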
    """

    def __init__(self, name, args=None, replica_id=0):
        logger.info(f"{LOG_START} basemodule {name} init start")
        self.name = name
        if args is None:
            global_args = get_args()
        else:
            global_args = args
            set_global_variables(args)
        self.global_args = global_args
        args = global_args.models[name]
        self.total_gpu = args.num_gpu
        self.total_cpu = args.num_cpu
        self.gpu_per_process = args.gpu_per_process
        self.trainable = args.trainable
        self._runtime_args = self.global_args.runtime_args
        self._module_args = args
        self.replica_id = replica_id
        self.config_dir = args.config_dir
        self._is_colocate = False

        if self.total_gpu > 0:
            self._num_gpu_per_replica = (
                args.tensor_model_parallel_size
                * args.pipeline_model_parallel_size
                * args.expert_model_parallel_size
                * args.zero_size
                * args.fsdp_size
            )
            assert self._num_gpu_per_replica <= self.total_gpu, \
                f"_num_gpu_per_replica {self._num_gpu_per_replica} larger than total_gpu {self.total_gpu} " + \
                f"tp_size: {args.tensor_model_parallel_size} pp_size: {args.pipeline_model_parallel_size} " + \
                f"ep_size: {args.expert_model_parallel_size} zero_size: {args.zero_size}"
            assert self.total_gpu % self._num_gpu_per_replica == 0
            if not self.trainable:
                self._num_replica = args.num_gpu // self._num_gpu_per_replica
            else:
                # For trainable models, perform the DP inside DistActor
                self._num_replica = 1
                self._num_gpu_per_replica = self.total_gpu
        else:
            self._num_gpu_per_replica = 0
            self._num_replica = args.num_replica

        assert self._num_replica >= 1
        self._param_ranks = None
        self._named_parameters = None
        self._param_to_name = None
        self._parameters = None
        self._coalesced_parameters = None
        self.error_signal = None
        self._rank = None
        self._world_size = None
        self._group_names = []
        self._dataloader = None
        self._eval_dataloader = None
        self._kl_coef = None
        self._padding_config = {}
        self._storage = None
        self._timers = None
        self._data_iter = None
        self._eval_data_iter = None
        self.call_funcs = []
        self.trainable_funcs = []
        self._data_ckpt_manager = None
        self._peak_memory = 0
        self._parameters_to_sync = defaultdict(list)
        self._parameters_to_send = defaultdict(list)
        self._parameters_to_recv = defaultdict(list)
        self._parameters_shape = []
        # current compute iteration
        self._iteration = 0
        self._train_iteration = 0
        self._episode_id = 0
        self.enable_lora = self._module_args.lora.enable_lora
        self._finalized = False
        self._resume_training = False
        self._address = dlc_utils.get_addr() if dlc_utils.in_dlc_env() else get_host_addr()
        self._is_master_node = os.environ.get("RANK", '0') == '0'
        self._logger = setup_logger(model_name=self.name, ip_addr=self._address)
        # parameter sync from src_model
        self._src_parameter_model = None
        self.profiler = None
        self._buffer_num = {}
        self._tp_division = {}
        self._tp_num_mapping = 1
        self._sync_buffer = defaultdict(list)
        self._sync_dst_rank_to_src_ranks = {}
        self._expert_sync_buffer = {}
        self._synchronizer = None
        self._metric_prefix = ""
        self._metric_list = []
        self._stage_resume_done = False
        logger.info(f"{LOG_START} basemodule {name} init done")

    def get_sync_buffer(self):
        return self._sync_buffer

    def set_tp_num_mapping(self, _tp_num_mapping):
        self._tp_num_mapping = _tp_num_mapping

    @property
    def tp_num_mapping(self):
        return self._tp_num_mapping

    def set_buffer_num(self, buffer_num):
        self._buffer_num.update(buffer_num)

    def get_buffer_num(self, param_names):
        return [self._buffer_num[name] for name in param_names]

    def set_tp_division(self, tp_division):
        self._tp_division.update(tp_division)

    def get_tp_division(self, param_names):
        return [self._tp_division[name] for name in param_names]

    @property
    def is_colocate(self):
        return self._is_colocate

    def set_colocate(self, flag):
        self._is_colocate = flag

    def finalize(self):
        """
        Finalize the module; any user change after finalize will not take effect.

        :meta private:
        """
        self._finalized = True

    def _assert_not_finalized(self):
        """
        :meta private:
        """
        assert not self._finalized, f"{self} is finalized, any change to the class should happen before finalize."

    def get_runtime_args(self):
        return self.runtime_args

    @property
    def runtime_args(self):
        """
        Return the runtime arguments related to alignment training, i.e., the settings
        specified under the "runtime" section of the YAML configuration file.
        """
        return self._runtime_args

    @property
    def model_args(self):
        """
        Return model arguments, such as those related to Megatron; these should be
        specified in a separate model configuration YAML file.
        """
        return self._module_args.args_dict

    @property
    def module_args(self):
        """
        Return module arguments. module_args include `num_gpu`, `gpu_per_process`, `model_config_file`, etc.
        """
        return self._module_args

    @property
    def parameter_sync_frequency(self):
        return self.module_args.sync_frequency

    def set_env(self, args):
        """
        Set the system environment, private.

        :meta private:
        """

    def set_error_signal(self, error_signal):
        """
        Set the signal used for error handling.

        :meta private:
        """
        self.error_signal = error_signal

    def error(self, error_msg=None):
        """
        :meta private:
        """
        future.wait(self.error_signal.set.remote(error_msg))

    def init(self):
        """
        Init env.
        """

    def setup(self):
        """
        Create model / optimizer / opt_param_scheduler / etc.
        """

    @property
    def data_ckpt_manager(self):
        """
        :meta private:
        """
        if self.runtime_args.data_checkpoint_path is not None:
            assert self._data_ckpt_manager is not None
        return self._data_ckpt_manager

    def model_setup(self):
        """
        :meta private:
        """
        self.global_args.active_module_args = self._module_args
        if self.runtime_args.data_checkpoint_path is not None:
            self._data_ckpt_manager = CheckpointManager(self, self.runtime_args.data_checkpoint_path,
                                                       self.runtime_args.max_data_ckpt_nums,
                                                       self.runtime_args.load_data_checkpoint_iteration)
            if self.runtime_args.enable_resume_training:
                meta = self._data_ckpt_manager.resume()
                if meta:
                    self._resume_training = self.runtime_args.consumed_samples > 0
                    start_episode = meta["episode"] + 1
                    self._episode_id = start_episode
                    self._iteration = start_episode * math.ceil(self.runtime_args.sample_per_episode / \
                        self._num_replica / self.module_args.generation_batch_size)

                    log_rank_0(
                        f"{self.name} resume training {self._resume_training}: "
                        f"set start iteration to {self._iteration} and episode id to {self._episode_id}",
                        self._logger)
        self.setup()

    def forward_step(self, data, iteration):
        """
        Perform forward step for one batch.

        Args
        ----
        data : dict
            data for forward_step
        iteration : int
            local forward iteration
        
        Returns
        -------
        Dict
            A dict of results, where each key is a string and each value is a tensor or a list;
            the first dim of the tensor or the length of the list equals the batch size.
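
        Example
        -------
        A hypothetical result dict for a batch of 2 samples (key names and
        shapes are illustrative, not required)::

            {
                "str_outputs": ["response 0", "response 1"],  # len == batch size
                "logprobs": torch.zeros(2, 128),              # first dim == batch size
            }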
        """

    def train_step(self, data, iteration):
        """
        Perform train_step for one batch, including a list of micro-batches.

        Args
        ----
        data : [Dict]
            A list of micro-batches for train_step; each micro-batch is a dict
        iteration : int
            local train iteration
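
        Example
        -------
        A hypothetical ``data`` argument holding two micro-batches (key names
        and shapes are illustrative)::

            [
                {"input_ids": torch.zeros(4, 128, dtype=torch.long)},
                {"input_ids": torch.zeros(4, 128, dtype=torch.long)},
            ]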
        """

    def eval_step(self, data):
        """
        Perform eval_step for one batch

        Args
        ----
            data: Dict
                Data for eval_step.

        Returns
        -------
            Dict
                A dict of results, where each key is a string and each value is a tensor or a list;
                the first dim of the tensor or the length of the list equals the batch size.
        """

    def save_checkpoint(self, iteration):
        """
        Save checkpoint given iteration.

        Args
        ----
            iteration: int
                Current training iteration
        """

    def save_data_checkpoint(self, replica_id, iteration, episode_id):
        """
        Save checkpoint for dataloader.

        :meta private:
        """
        if self.data_ckpt_manager is not None:
            consumed_samples = self.runtime_args.consumed_samples
            self.data_ckpt_manager.save_checkpoint(replica_id, iteration, episode_id, consumed_samples)

    def put(self, key, data):
        """
        Put the data to shared storage.

        Args
        ----
            key: Str
                The key under which to store the data.
            data
                The data to save.
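
        Example
        -------
        A minimal sketch, assuming the shared storage actor has been attached
        via ``set_storage`` and using an illustrative key name::

            self.put("policy_output", {"rewards": [0.1, 0.2]})
            data = self.get("policy_output")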
        """
        self._storage.put.remote(key, data)

    def get(self, key):
        """
        Get data from shared storage using key

        Args
        ----
            key: Str
                The key used to retrieve the data.
        """
        ref = self._storage.get.remote(key)
        return future.get(ref)

    def validate(self):
        """
        :meta private:
        """

    def before_episode(self):
        """
        Operations before one episode.
        """

    def after_episode(self):
        """
        Operations after one episode.
        """
        self._episode_id += 1

    def build_dataset(self, train_prompts, is_eval=False):
        """
        Build prompt dataset

        Args
        ----
            train_prompts: [Str]
                A list of prompt strings.
            is_eval: bool
                Set to `True` to build the dataset for evaluation (default: `False`).
        Returns
        -------
            torch.utils.data.Dataset
                Dataset with user-defined collate_fn
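
        Example
        -------
        A minimal sketch of an override, where ``PromptDataset`` is a
        hypothetical ``torch.utils.data.Dataset`` exposing a ``collate_fn``::

            def build_dataset(self, train_prompts, is_eval=False):
                # wrap the raw prompt strings; collate_fn is picked up automatically
                return PromptDataset(train_prompts)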
        """

    def build_all_dataset(self, train_prompts_list, is_eval=False):
        """
        Build all prompt datasets

        Args
        ----
            train_prompts_list: List[List[Str]]
                A list of prompt string lists.
        Returns
        -------
            List[torch.utils.data.Dataset]
                A list of Datasets with user-defined collate_fn
        """
        all_datasets = []
        for train_prompts in train_prompts_list:
            all_datasets.append(
                self.build_dataset(train_prompts, is_eval)
            )
        return all_datasets

    def _build_dataloader(self, data, sample_per_episode, is_eval=False):
        """
        build and set the dataloader for the model

        Args:
            data: a list of prompt string lists, one list per dataset
            sample_per_episode: number of samples per episode
            is_eval: set to `True` to build a dataloader for evaluation (default: `False`)

        :meta private:
        """
        all_datasets = self.build_all_dataset(data, is_eval) # pylint: disable=assignment-from-no-return
        consumed_samples = 0
        data_ratio = self.runtime_args.data_ratio
        shuffle = self.runtime_args.data_shuffle
        data_rerank = self.runtime_args.data_rerank
        if not is_eval:
            if self.data_ckpt_manager is not None:
                consumed_samples = self.runtime_args.consumed_samples
        collate_fn = all_datasets[0].collate_fn if hasattr(all_datasets[0], 'collate_fn') else None
        drop_last = self.model_args['drop_last'] if 'drop_last' in self.model_args else False
        dataloader = self.build_dataloader(all_datasets,
                                           sample_per_episode=sample_per_episode,
                                           collate_fn=collate_fn,
                                           is_eval=is_eval,
                                           consumed_samples=consumed_samples,
                                           data_ratio=data_ratio,
                                           shuffle=shuffle,
                                           drop_last=drop_last,
                                           data_rerank=data_rerank)

        if is_eval:
            self._eval_dataloader = dataloader
            self._eval_data_iter = iter(self._eval_dataloader)
        else:
            self._data_iter = iter(dataloader)
            self._data_iter = cycle(self._data_iter)
            self._dataloader = dataloader

    def build_dataloader(self,
                         all_datasets,
                         sample_per_episode,
                         collate_fn=None,
                         is_eval=False,
                         consumed_samples=0,
                         data_ratio=None,
                         shuffle=True,
                         drop_last=False,
                         data_rerank=True):
        """
        build the dataloader for the model
        Args:
            all_datasets: a list of torch.utils.data.Dataset objects
            sample_per_episode: number of samples per episode
            collate_fn: set when loading from a map-style dataset (default: `None`)
            is_eval: set to `True` to build a dataloader for evaluation (default: `False`)
            consumed_samples: consumed samples (default: `0`)
            data_ratio: ratio of samples for each dataset (default: `None`)
            shuffle: whether to shuffle the samples (default: `True`)
            drop_last: whether to drop the last samples (default: `False`)
            data_rerank: whether to rerank the data (default: `True`)

        :meta private:
        """
        log_rank_0(
            f"Creating DataLoader... consumed_samples: {consumed_samples}, "
            f"data_ratio: {data_ratio}",
            self._logger
        )
        if "num_inference_per_prompt" in self.model_args:
            num_inference_per_prompt = self.model_args["num_inference_per_prompt"]
        else:
            num_inference_per_prompt = 1
        vllm_prompt_key = self.model_args["vllm_prompt_key"] \
            if "vllm_prompt_key" in self.model_args else "prompt"
        self._logger.info(f"====Data Rerank: {data_rerank}")
        if is_eval:
            batch_sampler = MultiDatasetSampler(
                dataset_sizes=[len(dataset) for dataset in all_datasets],
                sample_per_episode=sample_per_episode,
                shuffle=False,
                is_eval=True,
                data_parallel_rank=self.replica_id,
                data_parallel_size=self._num_replica
            )
        else:
            batch_sampler = MultiDatasetSampler(
                dataset_sizes=[len(dataset) for dataset in all_datasets],
                sample_per_episode=sample_per_episode,
                data_ratio=data_ratio,
                consumed_samples=consumed_samples,
                num_inference_per_prompt=num_inference_per_prompt,
                shuffle=shuffle,
                is_eval=False,
                data_parallel_rank=self.replica_id,
                data_parallel_size=self._num_replica,
                drop_last="drop" if drop_last else "cycle",
                data_rerank=data_rerank
            )
        return RLHFDataLoader(
            all_datasets,
            batch_sampler,
            collate_fn=collate_fn,
            add_uid=True,
            data_parallel_rank=self.replica_id,
            data_parallel_size=self._num_replica,
            num_inference_per_prompt=num_inference_per_prompt,
            vllm_prompt_key=vllm_prompt_key
        )

    def reset_eval_data_iter(self):
        """
        :meta private:
        """
        if self._eval_dataloader is not None:
            self._eval_data_iter = iter(self._eval_dataloader)

    def next_batch(self, is_eval=False):
        """
        :meta private:
        """
        if is_eval:
            return next(self._eval_data_iter)
        else:
            return next(self._data_iter)

    @property
    def num_replica(self):
        """
        :meta private:
        """
        return self._num_replica

    @property
    def num_gpu_per_replica(self):
        """
        :meta private:
        """
        return self._num_gpu_per_replica

    def setup_collective_group(self, rank, world_size, backend, group_name):
        """
        :meta private:
        """
        self._group_names.append(group_name)
        self._world_size = world_size
        col.init_collective_group(
            world_size, rank, backend=backend, group_name=group_name)

    def broadcast_dummy_tensor_send(self, src_rank, group_name):
        x = torch.zeros(1, device="cuda")
        col.broadcast(x, src_rank=src_rank, group_name=group_name)
        del x

    def broadcast_dummy_tensor_recv(self, src_rank, group_name):
        x = torch.zeros(1, device="cuda")
        col.broadcast(x, src_rank=src_rank, group_name=group_name)
        del x

    def _destroy_collective_group(self, group_name):
        """
        :meta private:
        """
        from ray.util.collective.collective import _group_mgr # pylint: disable=import-outside-toplevel
        rank = col.get_rank(group_name)
        saved_group: BaseGroup = _group_mgr.get_group_by_name(group_name)
        saved_comm_keys = []
        if isinstance(saved_group, (NCCLGroup, )):
            saved_comm_keys = list(saved_group._dev_comm_map.keys())

        try:
            col.destroy_collective_group(group_name)
        except Exception as e:
            self._logger.warning(f"_destroy_collective_group {group_name} {e}")

        if isinstance(saved_group, (NCCLGroup, )):
            for comm_key in saved_comm_keys:
                group_key = saved_group._generate_group_key(comm_key)
                from ray.util.collective.const import get_store_name # pylint: disable=import-outside-toplevel
                store_name = get_store_name(group_key)
                try:
                    store = ray.get_actor(store_name)
                    if rank == 0:
                        raise RuntimeError(f'{store_name} in group {group_name} should be killed on rank {rank}.')
                    self._logger.debug(f'Kill {store_name} in group {group_name} on rank {rank}')
                    ray.kill(store)
                except ValueError:
                    ...

    def destroy_collective_group(self):
        for group_name in self._group_names:
            self._destroy_collective_group(group_name)
        self._group_names = []

    def get_local_param_ranks(self):
        """
        :meta private:
        """

    def fuse_lora_layer(self):
        """
        :meta private:
        """
        from chatlearn.models.megatron.lora import fuse_lora_layer # pylint: disable=import-outside-toplevel
        fuse_lora_layer(self.model)

    def unfuse_lora_layer(self):
        """
        :meta private:
        """
        from chatlearn.models.megatron.lora import unfuse_lora_layer # pylint: disable=import-outside-toplevel
        unfuse_lora_layer(self.model)

    @property
    def rank(self):
        """
        :meta private:
        """
        return self._rank

    def get_rank(self):
        """
        :meta private:
        """
        return self.rank

    def is_last_rank(self):
        """
        Is last rank.
        """
        return True

    @property
    def parameters(self):
        """
        :meta private:
        """
        if self._parameters is None:
            if not isinstance(self.model, list):
                model = [self.model]
            else:
                model = self.model
            self._parameters = []
            for partition in model:
                for item in partition.parameters():
                    self._parameters.append(item)
        return self._parameters

    @property
    def named_parameters(self):
        """
        :meta private:
        """
        if self._named_parameters is None:
            if not isinstance(self.model, list):
                model = [self.model]
            else:
                model = self.model
            self._named_parameters = {}
            for partition in model:
                for item in partition.named_parameters():
                    self._named_parameters[item[0]] = item[1]
        return self._named_parameters

    @property
    def param_to_name(self):
        """
        :meta private:
        """
        if self._param_to_name is None:
            if not isinstance(self.model, list):
                model = [self.model]
            else:
                model = self.model
            self._param_to_name = {}
            for partition in model:
                for item in partition.named_parameters():
                    self._param_to_name[item[1]] = item[0]
        return self._param_to_name

    def _set_sync_parameters(self, trainable_param_names, pipe_stage=0, parameters_to_sync=None):
        if parameters_to_sync is None:
            parameters_to_sync = defaultdict(list)
        assert pipe_stage not in parameters_to_sync or len(parameters_to_sync[pipe_stage])==0
        params_to_sync_list = [(name, self.named_parameters[name]) for name in trainable_param_names]
        if self._synchronizer is not None:
            params_to_sync_list = self._synchronizer.transform_parameters(params_to_sync_list)
        parameters_to_sync[pipe_stage] = params_to_sync_list
        return parameters_to_sync

    def set_sync_parameters(self, trainable_param_names, pipe_stage=0, parameters_to_sync=None):
        """
        :meta private:
        """
        if parameters_to_sync is None:
            parameters_to_sync = self._parameters_to_sync
        if pipe_stage not in parameters_to_sync or len(parameters_to_sync[pipe_stage]) == 0:
            self._set_sync_parameters(trainable_param_names, pipe_stage, parameters_to_sync)

    def reset_sync_parameters(self, trainable_param_names, pipe_stage=0):
        self._parameters_to_sync[pipe_stage] = []
        self._set_sync_parameters(trainable_param_names, pipe_stage, self._parameters_to_sync)

    def set_send_parameters(self, trainable_param_names, pipe_stage=0):
        """
        :meta private:
        """
        return self.set_sync_parameters(trainable_param_names, pipe_stage, self._parameters_to_send)

    def set_recv_parameters(self, to_rank, trainable_param_names, pipe_stage=0):
        """
        :meta private:
        """
        parameters_to_recv = defaultdict(list)
        self._parameters_to_recv[to_rank] = parameters_to_recv
        return self.set_sync_parameters(trainable_param_names, pipe_stage, parameters_to_recv)

    def clear_sync_parameters(self):
        self._parameters_to_sync = defaultdict(list)

    def clear_send_recv_parameters(self):
        self._parameters_to_send = defaultdict(list)
        self._parameters_to_recv = defaultdict(list)

    def clear_sync_send_recv_parameters(self):
        self.clear_sync_parameters()
        self.clear_send_recv_parameters()

    def get_parameter_names(self, requires_grad=True):
        """
        :meta private:
        """
        param_to_name = self.param_to_name
        if requires_grad:
            return [param_to_name[param] for param in self.parameters if param.requires_grad]
        else:
            return [param_to_name[param] for param in self.parameters]

    def get_parameter_shape(self, pipe_stage=0, parameters_to_sync=None):
        """
        :meta private:
        """
        if parameters_to_sync is None:
            parameters_to_sync = self._parameters_to_sync
        parameters_shape = []
        for name, param in parameters_to_sync[pipe_stage]:
            if self._expert_sync_buffer and name in self._expert_sync_buffer and \
                    self._synchronizer and self._synchronizer.is_parameter_changed:
                parameters_shape.append((name, self._expert_sync_buffer[name].shape))
            else:
                parameters_shape.append((name, param.shape))
        return parameters_shape

    def get_parameter(self, name):
        """
        :meta private:
        """
        if name not in self.named_parameters:
            raise Exception(f"parameter {name} not exits")
        return self.named_parameters[name]

    def get_parameter_to_sync(self, name, pipe_stage, to_cpu=False, regroup=False):
        assert pipe_stage in self._parameters_to_sync and len(self._parameters_to_sync[pipe_stage]) > 0
        for name0, param in self._parameters_to_sync[pipe_stage]:
            if name0 == name:
                if name in self._expert_sync_buffer and self._synchronizer and \
                        self._synchronizer.is_parameter_changed:
                    param = self._expert_sync_buffer[name]
                    regroup_routed_experts = True
                else:
                    regroup_routed_experts = False
                if regroup and self._synchronizer:
                    param = self._synchronizer.regroup_params_to_sync(
                        name,
                        param.data,
                        self._tp_division[name],
                        regroup_routed_experts
                    )
                if to_cpu:
                    param = param.cpu()
                else:
                    param = param.cuda()
                return param

    def get_parameter_to_sync_names(self, pipe_stage):
        return [items[0] for items in self._parameters_to_sync[pipe_stage]]

    def exist_parameter(self, name):
        """
        :meta private:
        """
        return name in self.named_parameters

    def parameter_shape(self, name):
        """
        :meta private:
        """
        return self.get_parameter(name).shape

    def send_recv_parameter(self, rank, group_name, func, pipe_stage=0):
        """
        :meta private:
        """
        tensors = [param.data for _, param in self._parameters_to_sync[pipe_stage]]
        dense_buckets, sparse_bucket = bucket_tensors(tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb)
        debug_rank_0(f"{self.name} Got dense_buckets {len(dense_buckets)}, spase_bucket {len(sparse_bucket)}", self._logger)
        for bucket in dense_buckets:
            tensor_changed = func is col.recv
            coalesced_comm_dense(bucket, func, extra_args=(rank, group_name), tensor_changed=tensor_changed)
        for param in sparse_bucket:
            func(param, rank, group_name)

    def alltoall_routed_expert_parameter(self, pipe_stage=0):
        assert self._synchronizer is not None
        for name, param in self._parameters_to_sync[pipe_stage]:
            param, state = self._synchronizer.alltoall_routed_experts(
                name,
                param,
                self.tensor_and_expert_parallel_group()
            )
            if state:
                self._expert_sync_buffer.pop(name, "Not Found.")
                self._expert_sync_buffer[name] = param

    def allgather_routed_expert_parameter(self, group_name, pipe_stage=0):
        assert self._synchronizer is not None
        for name, param in self._parameters_to_sync[pipe_stage]:
            param, state = self._synchronizer.allgather_routed_experts(
                name,
                param,
                group_name,
                tp_rank=self.tensor_parallel_rank()
            )
            if state:
                self._expert_sync_buffer.pop(name, "Not Found.")
                self._expert_sync_buffer[name] = param

    def broadcast_parameter(self, rank, src_rank, group_name, pipe_stage=0):
        """
        :meta private:
        """
        tensors = []
        for name, param in self._parameters_to_sync[pipe_stage]:
            if self._expert_sync_buffer and name in self._expert_sync_buffer and \
                    (self._synchronizer and self._synchronizer.is_parameter_changed):
                tensors.append(self._expert_sync_buffer[name])
            else:
                tensors.append(param.data)

        assert len(tensors) > 0
        dense_buckets, sparse_bucket = bucket_tensors(tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb)
        debug_rank_0(f"{self.name} Got dense_buckets {len(dense_buckets)}, spase_bucket {len(sparse_bucket)}", self._logger)
        tensor_changed = rank != src_rank

        for bucket in dense_buckets:
            coalesced_comm_dense(bucket, col.broadcast, extra_args=(src_rank, group_name), tensor_changed=tensor_changed)

        for param in sparse_bucket:
            col.broadcast(param, src_rank, group_name)

    def broadcast_parameter_two_stage(self, to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False):
        """
        Arguments:
            to_rank: receive rank in the mapping from trainer to inference model.
            buffer_rank: index of the sync-buffer tensors to be sent in stage2.
            rank: destination rank in the communication group, which enumerates the receive ranks.
            src_rank: source rank in the communication group; always 0.
            group_name: communication group name.
            pipe_stage: pipeline stage (default: 0).
            stage2: whether this is stage2 (default: False).
        Example: trainer_tp = 4, inference_tp = 8. pipeline_size = 1
            stage1: [(from_rank, to_rank), ...] = [(0, 8), (1, 10), (2, 12), (3, 14)]
            stage2: [(from_rank, to_rank), ...] = [(8, 9), (10, 11), (12, 13), (14, 15)]

            For stage1 pair (0, 8):
                1. call broadcast func: (0 -> 0). src_rank: 0, rank: 0.
                2. call broadcast func: (0 -> 8). src_rank: 0, rank: 1.

                After (0, 8), to_rank 8 has received the tensor slices of ranks 8 and 9.

            For stage2 pair (8, 9):
                1. call broadcast func: (8 -> 8). src_rank: 0, rank: 0.
                2. call broadcast func: (8 -> 9). src_rank: 0, rank: 1.
                In (8 -> 8), we need to send the tp_slice of 'to_rank' 9, so buffer_rank is set to 9
                to fetch the corresponding tensors from the sync buffer.
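
            A sketch of how the pairs above could be derived (illustrative only;
            assumes the inference replica occupies global ranks 8..15 as in this example)::

                trainer_tp, inference_tp = 4, 8
                tp_num_mapping = inference_tp // trainer_tp        # 2
                inference_rank_offset = 8
                stage1 = [(i, inference_rank_offset + i * tp_num_mapping) for i in range(trainer_tp)]
                # [(0, 8), (1, 10), (2, 12), (3, 14)]
                stage2 = [(dst, dst + 1) for _, dst in stage1]
                # [(8, 9), (10, 11), (12, 13), (14, 15)]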
        """
        tensor_changed = rank != src_rank
        start = time.time()
        arguments = f"{to_rank}_{buffer_rank}_{rank}_{src_rank}_{group_name}_{pipe_stage}_{stage2}"

        if stage2:
            if tensor_changed:
                parameters_to_sync = self._parameters_to_recv[to_rank]
            else:
                parameters_to_sync = self._parameters_to_send
        else:
            if rank not in self._sync_dst_rank_to_src_ranks:
                self._sync_dst_rank_to_src_ranks.update({rank:[src_rank]})
                del self._sync_buffer
                self._sync_buffer = defaultdict(list)
            else:
                self._sync_dst_rank_to_src_ranks[rank].append(src_rank)
            parameters_to_sync = self._parameters_to_sync

        def tensor_generator():
            if stage2 and not tensor_changed and self._sync_buffer:  # pylint: disable=too-many-nested-blocks
                idx = 0
                for name, param in parameters_to_sync[pipe_stage]:
                    value = self._sync_buffer[buffer_rank % self.tp_num_mapping][idx].cuda() # restore from cpu
                    self._logger.debug(
                        f"Adding {name}({value.shape}) to sync for if branch from "
                        f"src_rank: {src_rank} to rank: {rank} in pipe_stage {pipe_stage}"
                    )
                    buffer_num = 1
                    idx += 1
                    yield value, buffer_num
                del self._sync_buffer[buffer_rank % self.tp_num_mapping]
            else:
                idx = 0
                for name, param in parameters_to_sync[pipe_stage]:
                    idx += 1
                    param_data = param.data
                    if rank and self._buffer_num and not stage2:
                        assert name in self._buffer_num, f"{name} not in self._buffer_num for rank {rank}"
                        buffer_num = self._buffer_num[name]
                    elif stage2:
                        buffer_num = 1
                    else:
                        if self._expert_sync_buffer and name in self._expert_sync_buffer:
                            param_data = self._expert_sync_buffer[name]
                            regroup_routed_experts = True # For routed experts in Qwen2vLLM
                        else:
                            regroup_routed_experts = False
                        # regroup src_tensor by tp_rank
                        param_data = self._synchronizer.regroup_params_to_sync(
                            name,
                            param_data,
                            self._tp_division[name],
                            regroup_routed_experts
                        )
                        # move self._expert_sync_buffer[name] to cpu mem to save gpu mem
                        if regroup_routed_experts and name in self._expert_sync_buffer:
                            cpu_expert = self._expert_sync_buffer[name].cpu()
                            del self._expert_sync_buffer[name]
                            self._expert_sync_buffer[name] = cpu_expert
                        buffer_num = 1
                    self._logger.debug(
                        f"Adding {name}({param_data.shape}) to sync for else branch from "
                        f"src_rank: {src_rank} to rank: {rank} in pipe_stage {pipe_stage}"
                    )
                    yield param_data, buffer_num

        bucket_generator = bucket_tensors_two_stage_generator(
            tensor_generator, bucket_size_mb=self.runtime_args.coalesced_buffer_mb,
            stage2=stage2, tensor_changed=tensor_changed and not stage2
        )
        dense_bucket_num = 0
        sparse_bucket_num = 0
        for bucket_or_tensor, is_dense in bucket_generator:
            if is_dense:
                index = 0 if stage2 else (to_rank % self.tp_num_mapping)
                all_buffers = coalesced_comm_dense_two_stage(
                    bucket_or_tensor, col.broadcast, rank,
                    extra_args=(src_rank, group_name), tensor_changed=tensor_changed,
                    stage2=stage2, index=index)
                if tensor_changed and not stage2:
                    for key, value in all_buffers.items():
                        cpu_value = []
                        for tensor in value:
                            cpu_value.append(tensor.cpu().pin_memory()) # save gpu memory
                        del value
                        self._sync_buffer[key] += cpu_value
                    del all_buffers
                dense_bucket_num += 1
            else:
                col.broadcast(bucket_or_tensor, src_rank, group_name)
                sparse_bucket_num += 1

        if stage2:
            self._sync_dst_rank_to_src_ranks = {}

        self._logger.debug(f"broadcast_parameter_two_stage {arguments} done using {time.time()-start} seconds")
        debug_rank_0(f"{self.name} Got dense_buckets {dense_bucket_num}, sparse_bucket {sparse_bucket_num}", self._logger)

    def send_parameter(self, dst_rank, group_name, pipe_stage=0):
        """
        :meta private:
        """
        self.send_recv_parameter(dst_rank, group_name, col.send, pipe_stage)

    def recv_parameter(self, src_rank, group_name, pipe_stage=0):
        """
        :meta private:
        """
        self.send_recv_parameter(src_rank, group_name, col.recv, pipe_stage)

    def ray_put_parameter(self, group_name, pipe_stage=0):
        """
        :meta private:
        """
        name2ref = {}
        tensors = [param.data for _, param in self._parameters_to_sync[pipe_stage]]
        dense_buckets, sparse_bucket = bucket_tensors(tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb)
        debug_rank_0(f"{self.name} Put dense_buckets {len(dense_buckets)}, spase_bucket {len(sparse_bucket)}", self._logger)
        for bucket_id, bucket in enumerate(dense_buckets):
            flat_tensors = _flatten_dense_tensors(bucket)
            flat_tensors_ref = ray.put(flat_tensors)
            name2ref[group_name + ":dense_bucket_" + str(bucket_id)] = flat_tensors_ref
        for param_id, param in enumerate(sparse_bucket):
            param_ref = ray.put(param)
            name2ref[group_name + ":sparse_bucket_" + str(param_id)] = param_ref
        return name2ref

    def ray_get_parameter(self, group_name, name2ref, pipe_stage=0):
        """
        :meta private:
        """
        tensors = [param.data for _, param in self._parameters_to_sync[pipe_stage]]
        dense_buckets, sparse_bucket = bucket_tensors(tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb)
        debug_rank_0(f"{self.name} Get dense_buckets {len(dense_buckets)}, spase_bucket {len(sparse_bucket)}", self._logger)
        for bucket_id, bucket in enumerate(dense_buckets):
            put_ref = name2ref[group_name + ":dense_bucket_" + str(bucket_id)]
            flat_tensors = ray.get(put_ref)
            for tensor, synced in zip(
                bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
                tensor.copy_(synced)
        for param_id, param in enumerate(sparse_bucket):
            put_ref = name2ref[group_name + ":sparse_bucket_" + str(param_id)]
            param.copy_(ray.get(put_ref))

    def pipeline_model_parallel_size(self):
        """
        :meta private:
        """
        return self.module_args.pipeline_model_parallel_size

    def tensor_model_parallel_size(self):
        """
        :meta private:
        """
        return self.module_args.tensor_model_parallel_size

    def expert_model_parallel_size(self):
        """
        :meta private:
        """
        return self.module_args.expert_model_parallel_size

    def num_layers(self):
        """
        :meta private:
        """

    def set_storage(self, storage):
        """
        :meta private:
        """
        self._storage = storage

    def timers(self, name):
        """
        :meta private:
        """
        if self._timers is None:
            self._timers = Timers()
        return self._timers(name)

    def timer_summary(self, e2e_cost=None):
        """
        :meta private:
        """
        if self._timers:
            return self._timers.log(return_dict=True, e2e_cost=e2e_cost)

    def get_and_clear_metrics(self):
        """
        get logging metrics
        """
        if self._metric_list is None or len(self._metric_list) == 0:
            return self._metric_prefix, {}

        reduced_metrics = map_reduce_metrics(self._metric_list)
        self._metric_list = []
        return self._metric_prefix, reduced_metrics

    def add_padding_config(self, key, padding_value=0.0, padding_type="right"):
        """
        Add special padding config for a certain key.

        Args
        ----
        key: str
            The key for data to be padded.
        padding_value: float
            Padding value, default is 0.
        padding_type: str
            Padding type, can be "right" or "left" (default: "right").
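
        Example
        -------
        Left-pad a hypothetical "loss_mask" field with zeros::

            self.add_padding_config("loss_mask", padding_value=0.0, padding_type="left")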
        """
        self._padding_config[key] = {"padding_value": padding_value, "padding_type": padding_type}

    def padding_config(self):
        """
        :meta private:
        """
        return self._padding_config

    def peak_memory(self):
        """
        :meta private:
        """
        return 0.0

    @property
    def resume_training(self):
        """
        Whether training is resumed from the last checkpoint.
        """
        return self._resume_training

    def get_address(self):
        """
        Get node address

        :meta private:
        """
        return self._address

    def is_master_node(self):
        """
        Whether this node is the master node.
        :meta private:
        """
        return self._is_master_node

    def set_src_parameter_model(self, src_model):
        """
        Set the src_model that syncs parameters to the current model.
        :meta private:
        """
        self._src_parameter_model = src_model

    @property
    def src_parameter_model(self):
        """
        The src_model that syncs parameters to the current model.
        """
        return self._src_parameter_model

    def offload_optimizer_states(self):
        """
        offload optimizer states
        """

    def onload_optimizer_states(self):
        """
        onload optimizer states
        """

    def offload_main_weights(self):
        """
        offload main weights
        """

    def onload_main_weights(self):
        """
        onload main weights
        """

    def offload_weights(self):
        """
        offload weights
        """

    def onload_weights(self):
        """
        onload weights
        """

    def free_grad_buffers(self):
        """
        free grad buffers and related tensors
        """

    def build_grad_buffers(self):
        """
        build grad buffers and related tensors
        """

    def onload(self):
        pass

    def offload(self):
        pass

    @property
    def world_size(self):
        pass

    @property
    def data_parallel_size(self):
        """
        data parallel size

        :meta private:
        """

    @property
    def data_parallel_rank(self):
        """
        data parallel rank

        :meta private:
        """

    def empty_cache(self):
        """
        :meta private:
        """

    def get_data_parallel_rank(self):
        return self.data_parallel_rank

    def get_data_parallel_size(self):
        return self.data_parallel_size

    def get_pipeline_stage_layer_num(self):
        pass

    def get_pipeline_stage_layer_offset(self):
        return 0

    def set_synchronizer(self, synchronizer):
        self._synchronizer = synchronizer

    def expert_parallel_rank(self):
        """
        :meta private:
        """
        return 0

    def enable_stage_resume(self, is_eval):
        """
        check whether to resume stage outputs.
        """
        if is_eval:
            return False
        if self.model_args.get("enable_stage_resume", False):
            assert self.runtime_args.data_checkpoint_path, \
                "data_checkpoint_path must be set for stage resume."
            return True
        return False

    def get_stage_outputs_path(self, iteration):
        """
        get path for stage outputs.
        """
        save_dir = self.runtime_args.data_checkpoint_path
        save_path = f"{save_dir}/{iteration}/{self.name}_replica_{self.replica_id}.pt"
        save_path_meta = f"{save_dir}/{iteration}/{self.name}_replica_{self.replica_id}_meta.txt"
        return save_path, save_path_meta

    def load_stage_outputs(self, is_eval, iteration):
        """
        load stage outputs for resume.
        """
        outputs = None
        # only load once for each launching.
        if self.enable_stage_resume(is_eval) and not self._stage_resume_done:
            self._stage_resume_done = True
            save_path, save_path_meta = self.get_stage_outputs_path(iteration)
            if os.path.exists(save_path) and os.path.exists(save_path_meta):
                try:
                    with open(save_path_meta, "r", encoding='utf-8') as f:
                        replica_id = int(f.readline())
                    if replica_id == self.replica_id:
                        outputs = torch.load(save_path)
                        logger.info(f"resume stage outputs for model:{self.name}, path:{save_path}")
                except ValueError:
                    logger.warning(f"ignore incomplete stage outputs, path:{save_path}")
        return outputs

    def save_stage_outputs(self, is_eval, outputs, iteration):
        """
        save stage outputs for resume.
        """
        if self.enable_stage_resume(is_eval):
            save_path, save_path_meta = self.get_stage_outputs_path(iteration)
            logger.info(f"Start saving stage outputs:{save_path}")
            Path(save_path).parent.mkdir(parents=True, exist_ok=True)
            torch.save(outputs, save_path)
            # save meta
            with open(save_path_meta, "w", encoding='utf-8') as f:
                f.write(f"{self.replica_id}")
            logger.info(f"Finished to save stage outputs:{save_path}")
