# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""megatron to vllm synchronizer"""

from abc import abstractmethod
import operator
from functools import reduce
import ray.util.collective as col
import torch
from chatlearn.utils.constant import QwenVersion
from chatlearn.utils.utils import get_use_legacy_models
from chatlearn.utils.vllm_utils import fix_qwen_query_key_value_ordering
from chatlearn.utils.vllm_utils import split_attn_state
from chatlearn.utils.vllm_utils import Megatron2LlamaSyncMap, Megatron2QWenSyncMap, MCore2LlamaSyncMap
from chatlearn.utils.megatron_import_memory_helper import MegatronVersion, get_megatron_version
from .base import BaseSync

class MegatronVllmSync(BaseSync):
    """Megatron to vllm sync"""

    def __init__(self, src_model, dst_model):
        """Bind the sending and receiving models and cache their module args.

        :param src_model: model whose parameters are sent; must expose ``module_args``.
        :param dst_model: model that receives the parameters; must expose ``module_args``.
        """
        super().__init__(src_model, dst_model)
        self.src_module_args, self.dst_module_args = src_model.module_args, dst_model.module_args
        self.is_parameter_changed = True

    @abstractmethod
    def map_src_to_dst(self, src_names, src_pipe_layer_offset):
        """
        Build the sync map that translates source parameter names to destination names.

        Subclasses implement this. ``map_name_from_src_to_dst`` calls it and reads
        ``src_names``, ``dst_names`` and ``concat_params_dict`` from the returned object.

        :param src_names: parameter names of the source model.
        :param src_pipe_layer_offset: layer offset forwarded to the sync map.

        :meta private:
        """

    def _validate(self, sync_map):
        """Sanity-check the optional transform configs carried by ``sync_map``.

        ``concat_params_dict``, ``to_fix_act_ordering_dict`` and ``to_fix_qkv_ordering_dict``
        must each be ``None`` or a dict holding a list under ``"modules"`` plus a second
        mandatory key (``"dim"`` for the first two, ``"layer_re"`` for the last).

        :param sync_map: sync map object exposing the three attributes above.
        :raises RuntimeError: if a config is neither ``None`` nor a dict.
        :raises AssertionError: if a mandatory key is missing or ``"modules"`` is not a list.
        """
        concat_cfg = sync_map.concat_params_dict
        if concat_cfg is not None:
            if not isinstance(concat_cfg, dict):
                raise RuntimeError(f"Expect concat_params_dict in {self} to be a dict or None, while {sync_map.concat_params_dict}.")
            assert "modules" in concat_cfg
            assert "dim" in concat_cfg
            assert isinstance(concat_cfg["modules"], list)

        act_cfg = sync_map.to_fix_act_ordering_dict
        if act_cfg is not None:
            if not isinstance(act_cfg, dict):
                raise RuntimeError(f"Expect to_fix_act_ordering_dict in {self} to be a dict or None, while {sync_map.to_fix_act_ordering_dict}.")
            assert "modules" in act_cfg
            assert "dim" in act_cfg
            assert isinstance(act_cfg["modules"], list)

        qkv_cfg = sync_map.to_fix_qkv_ordering_dict
        if qkv_cfg is not None:
            if not isinstance(qkv_cfg, dict):
                raise RuntimeError(f"Expect to_fix_qkv_ordering_dict in {self} to be a dict or None, while {sync_map.to_fix_qkv_ordering_dict}.")
            assert "modules" in qkv_cfg
            assert "layer_re" in qkv_cfg
            assert isinstance(qkv_cfg["modules"], list)

    def map_name_from_src_to_dst(self, send_actor, recv_actor, src_names, dst_names):
        """Create, validate and store the sync map for one sender/receiver pair.

        The pipeline-stage layer offsets of both actors are fetched through
        ``get_or_cache`` and their sum is handed to ``map_src_to_dst``.

        :param send_actor: actor holding the source parameters.
        :param recv_actor: actor receiving the parameters.
        :param src_names: source parameter names to map.
        :param dst_names: not used by this implementation.
        :return: ``(src_names, dst_names)`` as reported by the sync map.
        """
        offset_method = "get_pipeline_stage_layer_offset"
        src_pipe_layer_offset = self.get_or_cache(send_actor, offset_method)
        dst_pipe_layer_offset = self.get_or_cache(recv_actor, offset_method)
        combined_offset = src_pipe_layer_offset + dst_pipe_layer_offset

        self.sync_map = self.map_src_to_dst(src_names, combined_offset)
        self._validate(self.sync_map)

        built_map = self.sync_map
        self.concat_params_dict = built_map.concat_params_dict
        return built_map.src_names, built_map.dst_names

    def concat_params(self, params_to_sync_list):
        """Fuse consecutive matching parameters into a single concatenated tensor.

        Parameters whose name contains any entry of ``concat_params_dict["modules"]`` are
        buffered. Once as many tensors as there are module entries have been collected,
        they are concatenated along ``concat_params_dict["dim"]`` and emitted under the
        name of the last buffered parameter. A trailing incomplete group is not emitted.

        :param params_to_sync_list: iterable of ``(name, tensor)`` pairs.
        :return: the input unchanged when no concat config is set, else a new list.
        """
        concat_cfg = self.sync_map.concat_params_dict
        if concat_cfg is None:
            return params_to_sync_list

        module_keys = concat_cfg["modules"]
        cat_dim = concat_cfg["dim"]
        merged = []
        pending = []
        for name, tensor in params_to_sync_list:
            if not any(key in name for key in module_keys):
                merged.append((name, tensor))
                continue
            pending.append(tensor)
            if len(pending) == len(module_keys):
                merged.append((name, torch.cat(pending, dim=cat_dim)))
                pending = []
        return merged

    def fix_qkv_ordering(self, params_to_sync_list):
        """Reorder fused QKV parameters in place using ``self._to_fix_qkv_ordering_func``.

        A parameter is processed when ``layer_re`` matches its name and the second regex
        group (the op name) is listed in ``to_fix_qkv_ordering_dict["modules"]``.
        When the ordering function is ``split_attn_state`` it is always applied; otherwise
        it is applied only if the tensor holds exactly ``heads * head_dim * 3`` leading
        elements (i.e. no grouped-query layout).

        :param params_to_sync_list: list of ``(name, tensor)`` pairs; updated in place.
        :return: the same list object.
        """
        qkv_cfg = self.sync_map.to_fix_qkv_ordering_dict
        if qkv_cfg is None:
            return params_to_sync_list

        layer_re = self.sync_map.to_fix_qkv_ordering_dict["layer_re"]
        target_ops = qkv_cfg["modules"]
        for idx, (name, tensor) in enumerate(params_to_sync_list):
            matched = layer_re.match(name)
            if matched is None or matched.group(2) not in target_ops:
                continue

            src_args = self.src_module_args.args_dict
            checkpoint_version = 3.0
            tp_size = src_args["tensor_model_parallel_size"]
            heads = src_args["num_attention_heads"] // tp_size
            hidden_size_per_head = src_args["hidden_size"] // src_args["num_attention_heads"]

            if self._to_fix_qkv_ordering_func is split_attn_state:
                if src_args["group_query_attention"]:
                    groups_per_rank = src_args["num_query_groups"] // tp_size
                else:
                    groups_per_rank = heads
                tensor = self._to_fix_qkv_ordering_func(
                    tensor, heads, groups_per_rank, hidden_size_per_head, src_args["hidden_size"])
                params_to_sync_list[idx] = (name, tensor)
                continue

            expected_shape = (heads, hidden_size_per_head, 3) + tensor.size()[1:]
            expected_numel = reduce(operator.mul, expected_shape, 1)
            if tensor.numel() != expected_numel:
                # model with gqa dont need to fix qkv ordering.
                continue

            weight_or_bias = matched.group(3)
            tensor = self._to_fix_qkv_ordering_func(
                tensor, checkpoint_version, 3, heads, hidden_size_per_head
            )
            if weight_or_bias == "weight":
                tensor = tensor.contiguous()
            params_to_sync_list[idx] = (name, tensor)
        return params_to_sync_list

    def fix_act_ordering(self, params_to_sync_list):
        """Swap the two halves (along dim 0) of matching parameters, in place.

        A parameter matches when its name contains any entry of
        ``to_fix_act_ordering_dict["modules"]``. The halves are re-joined in swapped
        order along ``to_fix_act_ordering_dict["dim"]``.

        :param params_to_sync_list: list of ``(name, tensor)`` pairs; updated in place.
        :return: the same list object.
        """
        act_cfg = self.sync_map.to_fix_act_ordering_dict
        if act_cfg is None:
            return params_to_sync_list

        fix_dim = act_cfg["dim"]
        module_keys = act_cfg["modules"]
        for idx, (name, tensor) in enumerate(params_to_sync_list):
            if not any(key in name for key in module_keys):
                continue
            half = tensor.shape[0] // 2
            front, back = tensor[:half, :], tensor[half:, :]
            params_to_sync_list[idx] = (name, torch.cat([back, front], dim=fix_dim))
        return params_to_sync_list

    def fix_shared_expert_ordering(self, params_to_sync_list):
        """Swap the two dim-0 chunks of matching shared-expert parameters, in place.

        A parameter matches when its name contains any entry of
        ``to_fix_shared_expert_ordering["modules"]``. The chunks are re-joined in swapped
        order along ``to_fix_shared_expert_ordering["dim"]`` and made contiguous.

        :param params_to_sync_list: list of ``(name, tensor)`` pairs; updated in place.
        :return: the same list object.
        """
        shared_cfg = self.sync_map.to_fix_shared_expert_ordering
        if shared_cfg is None:
            return params_to_sync_list

        fix_dim = shared_cfg["dim"]
        module_keys = shared_cfg["modules"]
        for idx, (name, tensor) in enumerate(params_to_sync_list):
            if not any(key in name for key in module_keys):
                continue
            front, back = tensor.chunk(2, dim=0)
            swapped = torch.cat([back, front], dim=fix_dim).contiguous()
            params_to_sync_list[idx] = (name, swapped)
        return params_to_sync_list

    def allgather_routed_experts_from_hep(self, name, params_to_sync, group_name, tp_rank):
        """
        This function is applicable for synchronizing parameters from QWen with HEP enabled
        to vLLM. In HEP, routed experts are split into a total number of EP size * TP size.
        Thus, the function will all-gather across EP size * TP size routed experts and slice
        them to TP size partitions.

        :param name: source parameter name, matched against the configured ``layer_re``.
        :param params_to_sync: local tensor contributed to the all-gather.
        :param group_name: collective group name passed to ``col.allgather``.
        :param tp_rank: index of the TP chunk kept from every gathered tensor.
        :return: ``(tensor, regrouped)``. ``regrouped`` is ``True`` only when the parameter
            was gathered and re-sliced; otherwise the input tensor is returned untouched.
        """
        # No all-gather config on the sync map: nothing to do.
        if self.sync_map._to_allgather_routed_experts_dict is None:
            return params_to_sync, False

        to_allgather_routed_experts_dict = self.sync_map._to_allgather_routed_experts_dict
        layer_re = to_allgather_routed_experts_dict["layer_re"]
        to_regroup_modules_list = to_allgather_routed_experts_dict["modules"]

        m = layer_re.match(name)
        if m is None:
            return params_to_sync, False

        # The second regex group is treated as the op name.
        op_name = m.group(2)
        if op_name in to_regroup_modules_list:
            tp_size = self.src_module_args.args_dict["tensor_model_parallel_size"]
            ep_size = self.src_module_args.args_dict["moe_expert_model_parallel_size"]
            hep_size = tp_size * ep_size
            moe_num_experts = self.src_module_args.args_dict["moe_num_experts"]
            local_num_experts = moe_num_experts // hep_size
            hidden_size = self.src_module_args.args_dict["hidden_size"]
            # One receive buffer per HEP rank, each shaped like the local tensor.
            output_tensor_list = [
                torch.empty(size=params_to_sync.shape, dtype=params_to_sync.dtype, device=params_to_sync.device)
                for _ in range(hep_size)
            ]
            col.allgather(output_tensor_list, params_to_sync, group_name)
            # The local tensor is no longer needed once gathered.
            del params_to_sync
            val_list = []
            if "dense_h_to_4h" in op_name:
                # w13_weight
                # pop(0) drops the list's reference to each gathered buffer once it is
                # consumed -- presumably to limit peak memory; confirm before changing.
                while output_tensor_list:
                    params = output_tensor_list.pop(0)
                    # regroup among difference tp slices
                    params = params.view((moe_num_experts, -1, hidden_size)).contiguous()
                    params = params.reshape((local_num_experts * 2, -1, hidden_size))
                    # Keep only this rank's TP chunk along dim 1.
                    params = params.chunk(tp_size, dim=1)[tp_rank]
                    # reorder w1 and w3
                    params = params.reshape(params.shape[0] // 2, -1, hidden_size)
                    # The first dim-1 half is named "right" and the second "left"; they are
                    # concatenated as [left, right], i.e. the two halves are swapped.
                    params_right, params_left = params.chunk(2, dim=1)
                    del params
                    params = torch.cat([params_left, params_right], dim=1)
                    del params_left
                    del params_right
                    val_list.append(params)
                params_to_sync = torch.cat(val_list, dim=0).contiguous()
            else:
                # w2_weight
                while output_tensor_list:
                    params = output_tensor_list.pop(0)
                    params = params.reshape((local_num_experts, -1, hidden_size))
                    # Keep only this rank's TP chunk along dim 1.
                    chunked_params = params.chunk(tp_size, dim=1)[tp_rank].contiguous()
                    del params
                    val_list.append(chunked_params)
                # Stack per-rank results on dim 0, then swap the last two dims.
                params_to_sync = torch.cat(val_list, dim=0).transpose(1, 2).contiguous()
            del val_list
            return params_to_sync, True
        else:
            return params_to_sync, False

    def allgather_routed_experts(self, name, params_to_sync, group_name, tp_rank): # pylint: disable=unused-argument
        """Dispatch the routed-expert all-gather according to the Megatron version.

        Only ``MegatronVersion.V4`` is supported and is delegated to
        ``allgather_routed_experts_from_hep``.

        :return: whatever ``allgather_routed_experts_from_hep`` returns.
        :raises NotImplementedError: for any other Megatron version.
        """
        is_hep_capable = get_megatron_version() == MegatronVersion.V4
        if not is_hep_capable:
            raise NotImplementedError(
                "ChatLearn does not support all-gathering routed experts for Megatron-LM, but supports QWen with HEP enabled. "
                "Please export `QWEN_VERSION` as `qwen_moe_v1`."
            )
        return self.allgather_routed_experts_from_hep(name, params_to_sync, group_name, tp_rank)

    def alltoall_routed_experts_from_hep(self, name, params_to_sync, comm_group):
        """
        This function is applicable for synchronizing parameters from QWen with HEP enabled
        to vLLM. In HEP, routed experts are split into a total number of EP size * TP size.
        Thus, the function will all-to-all across EP size * TP size routed experts.

        :param name: source parameter name, matched against the configured ``layer_re``.
        :param params_to_sync: local tensor to be split and exchanged.
        :param comm_group: process group passed to ``torch.distributed.all_to_all``.
        :return: ``(tensor, regrouped)``. ``regrouped`` is ``True`` only when the parameter
            was exchanged; otherwise the input tensor is returned untouched.
        """
        # No all-to-all config on the sync map: nothing to do.
        if self.sync_map._to_alltoall_routed_experts_dict is None:
            return params_to_sync, False

        to_alltoall_routed_experts_dict = self.sync_map._to_alltoall_routed_experts_dict
        layer_re = to_alltoall_routed_experts_dict["layer_re"]
        to_regroup_modules_list = to_alltoall_routed_experts_dict["modules"]

        m = layer_re.match(name)
        if m is None:
            return params_to_sync, False

        # The second regex group is treated as the op name.
        op_name = m.group(2)
        if op_name in to_regroup_modules_list:
            tp_size = self.src_module_args.args_dict["tensor_model_parallel_size"]
            ep_size = self.src_module_args.args_dict["moe_expert_model_parallel_size"]
            hep_size = tp_size * ep_size
            moe_num_experts = self.src_module_args.args_dict["moe_num_experts"]

            local_num_experts = moe_num_experts // hep_size
            hidden_size = self.src_module_args.args_dict["hidden_size"]
            if "dense_h_to_4h" in op_name:
                # w13_weight
                # regroup among difference tp slices
                param = params_to_sync.view((moe_num_experts, -1, hidden_size))
                param = param.reshape((local_num_experts * 2, -1, hidden_size))
                # One dim-1 chunk per HEP rank; chunk i is sent to rank i below.
                params = list(param.chunk(hep_size, dim=1))
                # reorder w1 and w3
                params_list = []
                while params:
                    param = params.pop(0)
                    param = param.reshape(param.shape[0] // 2, -1, hidden_size)
                    # The first dim-1 half is named "right" and the second "left"; they are
                    # concatenated as [left, right], i.e. the two halves are swapped.
                    param_right, param_left = param.chunk(2, dim=1)
                    del param
                    param = torch.cat([param_left, param_right], dim=1)
                    del param_left
                    del param_right
                    params_list.append(param)
                del params_to_sync
                # Receive buffers mirror the shapes of the chunks being sent.
                output = [
                    torch.empty(size=params_list[i].shape, dtype=params_list[i].dtype, device=params_list[i].device)
                    for i in range(hep_size)
                ]
                torch.distributed.all_to_all(output, params_list, group=comm_group)
                del params_list
                params_to_sync = torch.cat(output, dim=0).contiguous()
                del output
            else:
                # w2_weight
                param = params_to_sync.view((local_num_experts, -1, hidden_size))
                params = list(param.chunk(hep_size, dim=1))
                # Each chunk is made contiguous before it is handed to all_to_all.
                params_list = [ele.contiguous() for ele in params]
                del param
                del params
                del params_to_sync
                # Receive buffers mirror the shapes of the chunks being sent.
                output = [
                    torch.empty(size=params_list[i].shape, dtype=params_list[i].dtype, device=params_list[i].device)
                    for i in range(hep_size)
                ]
                torch.distributed.all_to_all(output, params_list, group=comm_group)
                del params_list
                # Stack received chunks on dim 0, then swap the last two dims.
                params_to_sync = torch.cat(output, dim=0).transpose(1, 2).contiguous()
                del output
            return params_to_sync, True
        else:
            return params_to_sync, False

    def alltoall_routed_experts(self, name, params_to_sync, comm_group): # pylint: disable=unused-argument
        """Dispatch the routed-expert all-to-all according to the Megatron version.

        Only ``MegatronVersion.V4`` is supported and is delegated to
        ``alltoall_routed_experts_from_hep``.

        :return: whatever ``alltoall_routed_experts_from_hep`` returns.
        :raises NotImplementedError: for any other Megatron version.
        """
        is_hep_capable = get_megatron_version() == MegatronVersion.V4
        if not is_hep_capable:
            raise NotImplementedError(
                "ChatLearn does not support all-to-all routed experts for Megatron-LM, but supports QWen with HEP enabled. "
                "Please export `QWEN_VERSION` as `qwen_moe_v1`."
            )
        return self.alltoall_routed_experts_from_hep(name, params_to_sync, comm_group)

    def transform_parameters(self, params_to_sync_list):
        """
        Run every parameter transform in its fixed order: concat, activation
        reordering, QKV reordering, then shared-expert reordering.

        :param params_to_sync_list: list of ``(name, tensor)`` pairs.
        :return: the transformed list.
        """
        ordered_steps = (
            self.concat_params,
            self.fix_act_ordering,
            self.fix_qkv_ordering,
            self.fix_shared_expert_ordering,
        )
        for step in ordered_steps:
            params_to_sync_list = step(params_to_sync_list)
        return params_to_sync_list

    def regroup_qkv_tp_slices(self, name, param_data, tp_division):
        """
        Regroup a fused QKV parameter so that each of ``tp_division`` slices holds its
        own query, key and value rows contiguously.

        Only parameters whose name contains ``attention.query_key_value``,
        ``self_attention.query_key_value`` or ``self_attention.linear_qkv`` are touched.
        Two layouts are handled:

        * the tensor holds exactly ``3 * heads * head_dim`` leading elements (no grouped
          query attention): regrouped only when a QKV-ordering config is present;
        * otherwise the rows are treated as ``heads`` query heads followed by the key
          groups and then the value groups: regrouped when a QKV-ordering config is
          present or the source holds a single query group per rank.

        :param name: source parameter name.
        :param param_data: tensor to regroup (NOTE(review): presumably the per-source-rank
            shard, since head counts are divided by the source TP size -- confirm).
        :param tp_division: number of slices the tensor is divided into.
        :return: the regrouped tensor, or ``param_data`` itself when nothing applies.
        :raises AssertionError: if source and destination ``num_query_groups`` differ, or
            the per-rank group count is not divisible by ``tp_division`` when larger.
        """
        param_data_shape = param_data.shape
        # Regroup qkv tensors into different tp slices only for inference model which enables vLLM backend.
        to_fix_qkv_ordering_dict = self.sync_map.to_fix_qkv_ordering_dict
        is_fused_qkv = (
            "attention.query_key_value" in name
            or "self_attention.query_key_value" in name
            or "self_attention.linear_qkv" in name
        )
        if not is_fused_qkv:
            return param_data

        src_args = self.src_module_args.args_dict
        src_tp_size = src_args["tensor_model_parallel_size"]
        dst_tp_size = self.dst_module_args.args_dict["tensor_model_parallel_size"]
        heads = src_args["num_attention_heads"] // src_tp_size
        hidden_size = src_args["hidden_size"]
        hidden_size_per_head = hidden_size // src_args["num_attention_heads"]

        param_shape = (3, heads, hidden_size_per_head) + param_data_shape[1:]
        if param_data.numel() == reduce(operator.mul, param_shape, 1):
            # Element count matches a plain (q, k, v) x heads layout.
            if to_fix_qkv_ordering_dict is not None:
                param_data = param_data.view(param_shape)
                head_offset = heads // tp_division
                param_data_list = [
                    param_data[:, idx * head_offset:(idx + 1) * head_offset]
                    for idx in range(tp_division)
                ]
                param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
                del param_data_list
            return param_data

        if src_args["group_query_attention"]:
            num_query_groups = src_args["num_query_groups"]
            dst_num_query_groups = self.dst_module_args.args_dict["num_query_groups"]
            # Use a local for the destination value so a failing assertion reports the
            # intended message rather than tripping over a misspelled attribute name.
            assert num_query_groups == dst_num_query_groups, (
                f"num_query_groups of src model ({num_query_groups}) must be equal to num_query_groups of "
                f"dst model ({dst_num_query_groups}). Please double-check your config."
            )
            src_num_query_groups_per_replica = num_query_groups // src_tp_size
            if dst_tp_size >= num_query_groups:
                num_dst_kv_head_replicas = dst_tp_size // num_query_groups
            else:
                num_dst_kv_head_replicas = 1
        else:
            src_num_query_groups_per_replica = heads
            num_dst_kv_head_replicas = 1

        if to_fix_qkv_ordering_dict is None and src_num_query_groups_per_replica != 1:
            return param_data

        is_1d = len(param_data_shape) == 1
        num_rows = heads + 2 * src_num_query_groups_per_replica
        if is_1d:
            param_data = param_data.view((num_rows, hidden_size_per_head))
        else:
            param_data = param_data.view((num_rows, hidden_size_per_head, hidden_size))

        param_data_list = []
        head_offset = heads // tp_division
        for idx in range(tp_division):
            q_start = idx * head_offset
            q_end = q_start + head_offset
            if num_dst_kv_head_replicas == 1:
                if src_num_query_groups_per_replica > tp_division:
                    assert src_num_query_groups_per_replica % tp_division == 0, (
                        f"num_query_groups per replica of src model ({src_num_query_groups_per_replica}) "
                        f"must be divisible by tp_division ({tp_division}). Please double-check your config."
                    )
                    kv_offset = src_num_query_groups_per_replica // tp_division
                else:
                    kv_offset = 1
                if src_num_query_groups_per_replica // tp_division:
                    # Each slice owns kv_offset consecutive groups, so slice idx starts
                    # idx * kv_offset groups in. Stepping by 1 would make neighbouring
                    # slices overlap whenever kv_offset > 1.
                    k_start = heads + idx * kv_offset
                else:
                    k_start = heads
                k_end = k_start + kv_offset
                v_start = k_start + src_num_query_groups_per_replica
                v_end = v_start + kv_offset
            else:
                # Several slices share one key/value group.
                k_start = heads + idx // num_dst_kv_head_replicas
                k_end = k_start + 1
                v_start = k_start + src_num_query_groups_per_replica
                v_end = v_start + 1

            q_proj = param_data[q_start:q_end].contiguous()
            k_proj = param_data[k_start:k_end].contiguous()
            v_proj = param_data[v_start:v_end].contiguous()

            qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=0)

            if is_1d:
                qkv_proj = qkv_proj.reshape(-1).contiguous()
            else:
                qkv_proj = qkv_proj.reshape(-1, hidden_size).contiguous()

            param_data_list.append(qkv_proj)
        param_data = torch.concat(param_data_list, dim=0)
        del param_data_list
        return param_data

    def regroup_params_to_sync(self, name, param_data, tp_division, regroup_routed_experts=False):
        """Regroup QKV slices first, then defer to the base-class regrouping.

        :param name: source parameter name.
        :param param_data: tensor to regroup.
        :param tp_division: number of slices the tensor is divided into.
        :param regroup_routed_experts: forwarded unchanged to the base class.
        :return: the result of the base-class ``regroup_params_to_sync``.
        """
        qkv_regrouped = self.regroup_qkv_tp_slices(name, param_data, tp_division)
        return super().regroup_params_to_sync(
            name, qkv_regrouped, tp_division, regroup_routed_experts
        )

class MegatronVllmQWenSync(MegatronVllmSync):
    """Megatron to vllm sync for qwen (``QwenVersion.v_1``)."""

    def map_src_to_dst(self, src_names, src_pipe_layer_offset):
        """
        Select ``fix_qwen_query_key_value_ordering`` as the QKV reordering function
        and build the QWen sync map for version 1.

        :meta private:
        """
        self._to_fix_qkv_ordering_func = fix_qwen_query_key_value_ordering
        qwen_version = QwenVersion.v_1.value
        return Megatron2QWenSyncMap(src_names, src_pipe_layer_offset, qwen_version)


class MegatronVllmQWen2Sync(MegatronVllmSync):
    """Megatron to vllm sync for qwen2 (``QwenVersion.v_2``)."""

    def map_src_to_dst(self, src_names, src_pipe_layer_offset):
        """
        Select ``split_attn_state`` as the QKV reordering function and build the
        QWen sync map for version 2.
        """
        self._to_fix_qkv_ordering_func = split_attn_state
        qwen_version = QwenVersion.v_2.value
        return Megatron2QWenSyncMap(src_names, src_pipe_layer_offset, qwen_version)


class MegatronVllmLlamaSync(MegatronVllmSync):
    """Megatron to vllm sync for llama."""

    def map_src_to_dst(self, src_names, src_pipe_layer_offset):
        """
        Build the Llama sync map, choosing the legacy-Megatron or MCore variant
        according to ``get_use_legacy_models`` on the source model's args.
        """
        src_args_dict = self.src_model.module_args.args_dict
        if get_use_legacy_models(src_args_dict):
            sync_map_cls = Megatron2LlamaSyncMap
        else:
            sync_map_cls = MCore2LlamaSyncMap
        self._to_fix_qkv_ordering_func = fix_qwen_query_key_value_ordering
        return sync_map_cls(src_names, src_pipe_layer_offset)
