# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dist Actor"""

from collections import defaultdict
import importlib
import inspect
from functools import partial

import ray
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from chatlearn.models.base_module import BaseModule
from chatlearn.utils import future
from chatlearn.utils.utils import parse_function_args

vllm_exist = importlib.util.find_spec("vllm")
if vllm_exist:
    from chatlearn.models.vllm_module import VLLMModule
    from chatlearn.models.vllm_module_v2 import VLLMModuleV2
    from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion

RAY_REMOTE = "remote"


class DistActor:
    """Manage a collection of actors"""

    def __init__(self, model: BaseModule,
                 gpu_per_node,
                 error_signal,
                 port_manager,
                 replica_id=0,
                 storage=None):
        self.total_gpu = model.total_gpu
        self.total_cpu = model.total_cpu
        self.gpu_per_process = model.gpu_per_process
        self.num_gpu_per_replica = model.num_gpu_per_replica
        self.trainable = model.trainable
        self.gpu_per_node = gpu_per_node
        self.model = model
        self.all_actors = []
        self.replica_id = replica_id
        self._port_manager = port_manager
        self.name = self.model.name
        self.error_signal = error_signal
        self.storage = storage
        # ranks for model update
        self.all_ranks = None
        self._init_done = False
        self._placement_group = None
        self.rank_to_actors = {}

    @property
    def module_args(self):
        return self.model.module_args

    @property
    def runtime_args(self):
        return self.model.runtime_args

    @property
    def master(self):
        return self.all_actors[0]

    @property
    def tailer(self):
        return self.all_actors[-1]

    @property
    def actor_num(self):
        return len(self.all_actors)

    def _get_func_args(self, func_name):
        func = getattr(self.model, func_name)
        return parse_function_args(func)

    def preprocess_actors(self):
        self.add_remote_func()

    def add_remote_func(self):
        for func_name, _ in inspect.getmembers(self.master):
            # ray.actor.ActorMethod
            if func_name.startswith('_'):
                continue
            dist_call = partial(self.call_remote_funcs, func_name)
            setattr(self, func_name, dist_call)

    def call_actor_remote_func(self, actor, func_name, *args, **kwargs):
        func = getattr(actor, func_name)
        remote_func = getattr(func, RAY_REMOTE)
        res = remote_func(*args, **kwargs)
        return res

    def call_remote_funcs(self, func_name, *args, **kwargs):
        """
        Call remote functions for a collection of actors.
        """
        results = []
        for actor in self.all_actors:
            res = self.call_actor_remote_func(actor, func_name, *args, **kwargs)
            results.append(res)
        return results

    def _create_actor(self, cls, num_gpus, placement_group, group_index, **kwargs):
        scheduling_strategy = PlacementGroupSchedulingStrategy(
            placement_group=placement_group,
            placement_group_bundle_index=group_index,
        )
        # use max_concurrency=1 to make sure only one task execute at one time
        actor = ray.remote(num_gpus=num_gpus, num_cpus=0)(cls) \
            .options(scheduling_strategy=scheduling_strategy) \
            .remote(self.model.name, self.model.global_args, self.replica_id, **kwargs)
        actor.set_error_signal.remote(self.error_signal)
        actor.set_storage.remote(self.storage)
        self.all_actors.append(actor)
        return actor

    def create_actor(self, num_gpus, placement_group, group_index):
        return self._create_actor(self.model.__class__, num_gpus, placement_group, group_index)

    def _setup_collective_group(self, rank_offset, world_size, group_name, backend="nccl"):
        refs = []
        all_ranks = []
        for i, actor in enumerate(self.all_actors):
            rank = i + rank_offset
            ref = actor.setup_collective_group.remote(
                rank=rank,
                world_size=world_size,
                backend=backend,
                group_name=group_name)
            refs.append(ref)
            all_ranks.append(rank)
            self.rank_to_actors[rank] = actor
        self.all_ranks = all_ranks
        return refs

    def _setup_ranks(self, rank_offset):
        all_ranks = []
        for i, actor in enumerate(self.all_actors):
            rank = i + rank_offset
            all_ranks.append(rank)
            self.rank_to_actors[rank] = actor
        self.all_ranks = all_ranks

    def terminate(self):
        # terminate when catching exceptions
        for actor in self.all_actors:
            ray.kill(actor)

    @property
    def placement_group(self):
        return self._placement_group

    @placement_group.setter
    def placement_group(self, pg):
        self._placement_group = pg

    def group_dist_actors_by_tp_rank(self):
        self.dp_rank_to_actors = defaultdict(list)
        self.data_parallel_size = future.get(self.all_actors[0].get_data_parallel_size.remote())
        if self.data_parallel_size is None:
            self.data_parallel_size = 1
        dp_ranks = future.wait([actor.get_data_parallel_rank.remote() for actor in self.all_actors], return_output=True)
        for actor, dp_rank in zip(self.all_actors, dp_ranks):
            self.dp_rank_to_actors[dp_rank].append(actor)

    def set_dist_env(self, revert_placement=False):
        pass

    def __str__(self):
        return f"{self.__class__.__name__}({self.name})[{self.replica_id}]"

    def __repr__(self):
        return f'<{self.__class__.__name__}({self.name})[{self.replica_id}] object at {hex(id(self))}>'


class DistTorchActor(DistActor):
    """DistTorchActor"""

    def reorder_actors(self, actors, revert_placement=False):
        gpu_per_node = min(self.gpu_per_node, self.model.num_gpu_per_replica)
        ordered_actors = []
        count = 0
        actor_gpus = []
        for actor in actors:
            gpus = future.get(actor.get_visible_gpus.remote())
            count += len(gpus)
            actor_gpus.append((actor, gpus))
            if count == gpu_per_node:
                actor_gpus.sort(key=lambda x: x[1][0])
                if revert_placement:
                    actor_gpus.reverse()
                ordered_actors += [a[0] for a in actor_gpus]
                actor_gpus = []
                count = 0
        return ordered_actors

    def set_dist_env(self, revert_placement=False):
        self.all_actors = self.reorder_actors(self.all_actors, revert_placement)
        master_addr = future.get(self.master.get_address.remote())
        master_port = future.get(self._port_manager.get_free_port.remote(master_addr))

        world_size = self.actor_num
        env_config = {"MASTER_ADDR": master_addr, "MASTER_PORT": master_port, "WORLD_SIZE": world_size}
        ret = []
        for rank, actor in enumerate(self.all_actors):
            env_config["RANK"] = rank
            if self.model.gpu_per_process == 1:
                local_rank = 0
            else:
                local_rank = rank % self.model.gpu_per_process
            env_config["LOCAL_RANK"] = local_rank
            ret.append(actor.set_env.remote(env_config))
        status = sum(future.get(ret))
        assert status == world_size

class DistVLLMActor(DistTorchActor):
    """DistVLLMActor"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.vllm_engine = None

    def create_actor(self, num_gpus, placement_group, group_index):
        if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
            kwargs = {
                "worker_module_name": "vllm.worker.worker",
                "worker_class_name": "Worker",
                "worker_class_fn": None,
                "trust_remote_code": True,
            }
        else:
            kwargs = {
                "vllm_actor_type" : "worker"
            }
        self._create_actor(self.model.__class__, num_gpus, placement_group, group_index, **kwargs)

    def create_engine_actor(self, num_gpus, placement_group, group_index):
        self.vllm_engine = self._create_actor(self.model.__class__, num_gpus, placement_group, group_index)
        self.model.engine = self.vllm_engine

    def call_vllm_engine_remote_funcs(self, func_name, *args, **kwargs):
        """
        Call remote functions for vllm_engine.
        """
        results = []
        res = self.call_actor_remote_func(self.vllm_engine, func_name, *args, **kwargs)
        results.append(res)
        return results

    def call_vllm_engine_and_workers_remote_funcs(self, func_name, *args, **kwargs):
        """
        Call remote functions for vllm_engine + workers.
        """
        results = []
        for actor in self.all_actors:
            res = self.call_actor_remote_func(actor, func_name, *args, **kwargs)
            results.append(res)
        res = self.call_actor_remote_func(self.vllm_engine, func_name, *args, **kwargs)
        results.append(res)
        return results

    def add_remote_func(self):
        for func_name, _ in inspect.getmembers(self.master):
            # ray.actor.ActorMethod
            if func_name.startswith('_') or func_name in ["peak_memory"]:
                continue
            if func_name in ["timer_summary"]:
                dist_call = partial(self.call_vllm_engine_remote_funcs, func_name)
            elif func_name in ["onload", "offload"]:
                if func_name == "onload":
                    new_func_name = "onload_for_workers"
                else:
                    new_func_name = "offload_for_workers"
                dist_call = partial(self.call_vllm_engine_remote_funcs, new_func_name)
            elif func_name in ["model_setup"]:
                dist_call = partial(self.call_vllm_engine_and_workers_remote_funcs, func_name)
            elif func_name in ["get_and_clear_metrics"]:
                dist_call = partial(self.call_vllm_engine_remote_funcs, func_name)
            else: # needed to check for other call_funs.
                dist_call = partial(self.call_remote_funcs, func_name)
            setattr(self, func_name, dist_call)

    @property
    def master(self):
        return self.vllm_engine

    def peak_memory(self):
        return self.model.peak_memory()


class DistModel:
    """DistModel"""

    def __init__(self):
        self.replicas = []
        self.name = None
        self.rank_to_actors = {}
        self.register_func()
        self._is_colocate = False
        self._colocate_models = []

    def add_replica(self, replica):
        self.replicas.append(replica)
        self.name = replica.name

    @property
    def trainable(self):
        return self.replicas[0].trainable

    @property
    def module_args(self):
        return self.replicas[0].module_args

    @property
    def actor_num(self):
        return sum(len(dist_actor.all_actors) for dist_actor in self.replicas)

    @property
    def num_replica(self):
        return len(self.replicas)

    @property
    def total_gpu(self):
        return self.replicas[0].total_gpu

    @property
    def total_cpu(self):
        return self.replicas[0].total_cpu

    @property
    def num_gpu_per_replica(self):
        return self.replicas[0].num_gpu_per_replica

    @property
    def gpu_per_process(self):
        return self.replicas[0].gpu_per_process

    @property
    def is_colocate(self):
        return self._is_colocate

    @is_colocate.setter
    def is_colocate(self, flag):
        self._is_colocate = flag

    def get_actor(self, rank):
        # given rank, return the actor
        for dist_actor in self.replicas:
            if rank in dist_actor.rank_to_actors:
                return dist_actor.rank_to_actors[rank]

    def init(self):
        refs = []
        for dist_actor in self.replicas:
            refs.append(dist_actor.init())
        future.get(refs)

    def register_func(self):
        for func_name in ["model_setup",
                          "before_episode",
                          "after_episode",
                          "get_and_clear_metrics",
                          "validate",
                          "destroy_collective_group",
                          "terminate",
                          "peak_memory",
                          "empty_cache",
                          "set_start_iteration",
                          "offload",
                          "onload",
                          "eval",
                          "train",
                          "set_src_parameter_model",
                          "set_colocate"]:
            dist_call = partial(self.call_replica_func, func_name)
            setattr(self, func_name, dist_call)

    def call_replica_func(self, func, *args, **kwargs):
        refs = []
        for dist_actor in self.replicas:
            ref = getattr(dist_actor, func)(*args, **kwargs)
            if ref is not None:
                refs.append(ref)
        return refs

    def call_replica_serial_func(self, func, *args, **kwargs):
        results = []
        for dist_actor in self.replicas:
            ref = getattr(dist_actor, func)(*args, **kwargs)
            if ref is not None:
                res = future.get(ref)
                results.append(res)
        return results

    def set_colocate_models(self, models):
        self._colocate_models = models

    def colocate_with(self, model):
        return model in self._colocate_models

    @property
    def colocate_models(self):
        return self._colocate_models

    @property
    def all_ranks(self):
        return [dist_actor.all_ranks for dist_actor in self.replicas]

    @property
    def use_vllm_backend(self):
        return vllm_exist and isinstance(self.replicas[0].model, (VLLMModule, VLLMModuleV2))

    def group_dist_actors_by_tp_rank(self):
        for replica in self.replicas:
            replica.group_dist_actors_by_tp_rank()

    @property
    def enable_offload(self):
        return self.module_args.free_grad_buffers or self.module_args.offload_weights or \
            self.module_args.offload_optimizer_states

    def __str__(self):
        return f"{self.__class__.__name__}({self.name})"

    def __repr__(self):
        return f'<{self.__class__.__name__}({self.name}) object at {hex(id(self))}>'
