# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""model manager"""

import concurrent.futures
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
import time

import ray
import ray.experimental.state.api

from chatlearn.data.storage import Storage
from chatlearn.launcher import dlc_utils
from chatlearn import FSDPModule
from chatlearn.models.torch_module import TorchModule
from chatlearn.models.vllm_module_v2 import VLLMModuleV2
from chatlearn.runtime.decorator import decorate_class_func
from chatlearn.runtime.decorator import timeit, preprocess_compute, monitor_error
from chatlearn.runtime.dist_actor import DistActor, DistTorchActor, DistVLLMActor, DistModel
from chatlearn.synchronizer.parameter_sync import ParameterSyncGroup, ParameterSyncGroupwithHEP
from chatlearn.synchronizer.parameter_sync_fsdp import FSDP2VllmParameterSyncGroup
from chatlearn.utils.constant import LOG_START
from chatlearn.utils.error_monitor import ErrorMonitor, ErrorSignalActor
from chatlearn.utils.logger import logger
from chatlearn.utils.global_vars import set_decorated, is_decorated
from chatlearn.utils.megatron_import_memory_helper import MegatronVersion, get_megatron_version
from .port_manager import PortManager
from ..utils import future


class ModelManager:
    """Turn local models into distributed (Ray) models and manage their lifecycle.

    Responsibilities visible in this class:
      * wrap each local model in a ``DistModel`` made of Dist*Actor replicas and
        place the replicas on GPUs/CPUs, honouring ``runtime_args.colocation``;
      * register, build and run parameter-sync groups between model pairs;
      * start the error monitor and tear everything down in ``clean``.
    """

    def __init__(self, models, resouce_manager, global_args):
        self.local_models = models
        self.resouce_manager = resouce_manager
        self.dist_models = []
        self.env_args = global_args.env_args
        self.runtime_args = global_args.runtime_args
        self.converted = False
        # port for DLC jobs, the first two ports are reserved for ray start
        self.free_ports = dlc_utils.get_free_ports()[2:]
        self._port_manager = PortManager.remote(self.free_ports)
        self.error_signal = ErrorSignalActor.remote()
        self._storage = Storage.remote()
        self.parameter_sync_groups = {}
        self._parameter_sync_model_pair = []
        self.model_packs = []
        self.placement_groups = []
        # model name -> DistModel, filled by remote(). Initialised here so that
        # clean() does not raise AttributeError when remote() was never called.
        self._name2distmodel = {}

    def _get_total_gpu_required(self):
        """Return the number of GPUs needed to place every dist model.

        A colocation group costs the ``total_gpu`` of its largest member. Every
        model outside a colocation group costs its own ``total_gpu``.
        """
        total_gpu = 0
        colocated_names = set()
        for group in self.runtime_args.colocation:
            total_gpu += max(self._name2distmodel[name].total_gpu for name in group)
            colocated_names.update(group)
        for model in self.dist_models:
            # place non-colocate models
            if model.name not in colocated_names:
                total_gpu += model.total_gpu
        return total_gpu

    def remote(self) -> list:
        """Convert every local model to a DistModel and place it on remote devices.

        The call is idempotent: once conversion has happened, the cached
        ``dist_models`` list is returned unchanged.

        Returns:
            list: the DistModel objects, in ``local_models`` order.

        Raises:
            RuntimeError: if the job needs more GPUs than the resource manager holds.
        """
        if self.converted:
            return self.dist_models
        logger.info(f"{LOG_START} model_manager start to convert model to remote")
        t1 = time.time()

        self._name2distmodel = {}
        for model in self.local_models:
            # create dist model object for each local model
            dist_model = self._to_dist_model(model)
            self.dist_models.append(dist_model)
            self._name2distmodel[model.name] = dist_model

        total_gpu_required = self._get_total_gpu_required()
        total_gpu_applied = self.resouce_manager.total_gpu
        if total_gpu_required > total_gpu_applied:
            raise RuntimeError(f"The number of required gpus for current job is {total_gpu_required}, "
                               f"while the number of applied gpus is {total_gpu_applied}")
        if total_gpu_applied > total_gpu_required:
            logger.warning(f"The number of applied gpus is {total_gpu_applied}, "
                           f"while the number of required gpus is {total_gpu_required}, "
                           f"there is {total_gpu_applied - total_gpu_required} wasted gpus")

        t2 = time.time()
        logger.info(f"{LOG_START} model_manager convert model to remote, get_total_gpu_required(s):{(t2-t1)}")
        # (model, reverse_gpu_placement) pairs; dist env is set for all of them at the end
        env_list = []
        colocated_names = set()
        for group in self.runtime_args.colocation:
            colocate_models = [self._name2distmodel[name] for name in group]
            self.place_models_to_remote_devices(colocate_models, env_list)
            if len(colocate_models) > 1:
                set_colocate = []
                for model in colocate_models:
                    model.is_colocate = True
                    set_colocate.extend(model.set_colocate(True))
                future.wait(set_colocate)
            colocated_names.update(group)
        t3 = time.time()
        logger.info(f"{LOG_START} model_manager convert model to remote, set_colocate(s):{(t3-t2)}")

        for model in self.dist_models:
            # place non-colocate models
            if model.name not in colocated_names:
                self.place_models_to_remote_devices([model], env_list)
        self.set_dist_env_concurrent(env_list)
        self.converted = True
        t4 = time.time()
        logger.info(f"{LOG_START} model_manager convert model to remote, place_models_to_remote_devices(s):{(t4-t3)}")
        return self.dist_models

    def build_parameter_group(self):
        """Create one parameter sync group per pair registered with set_parameter_sync.

        The implementation is chosen as follows: an FSDP source model uses
        FSDP2VllmParameterSyncGroup; Megatron V4 uses ParameterSyncGroupwithHEP;
        anything else uses ParameterSyncGroup. Requires remote() to have run
        (dist models are looked up by name).
        """
        megatron_version = get_megatron_version()
        for src_model, dst_model in self._parameter_sync_model_pair:
            logger.info(
                f"start build parameter sync group bewteen {src_model.name} and {dst_model.name}")
            group_name = self._get_group_name(src_model, dst_model)
            sync_frequency = self._get_sync_frequency(dst_model)
            src_dist_model = self._name2distmodel[src_model.name]
            dst_dist_model = self._name2distmodel[dst_model.name]

            if isinstance(src_dist_model.replicas[0].model, FSDPModule):
                sync_group_cls = FSDP2VllmParameterSyncGroup
            elif megatron_version == MegatronVersion.V4:
                logger.info("QWEN_VERSION has been set to qwen_moe_v1, where HEP is enabled.")
                sync_group_cls = ParameterSyncGroupwithHEP
            else:
                sync_group_cls = ParameterSyncGroup
            self.parameter_sync_groups[group_name] = sync_group_cls(
                src_dist_model,
                dst_dist_model,
                group_name,
                sync_frequency,
                self.error_signal
            )

    def start_error_monitor(self):
        """Launch the ErrorMonitor actor over all dist models and sync groups.

        The ``monitor`` call is fired without waiting for its result.
        """
        group_names = list(self.parameter_sync_groups.keys())
        self.error_monitor = ErrorMonitor.remote(self.error_signal, self.dist_models, group_names)
        self.error_monitor.monitor.remote()

    def _get_group_name(self, src_model, dst_model):
        """Return the sync-group key for a (src, dst) pair, e.g. ``"policy2ref"``."""
        return src_model.name + "2" + dst_model.name

    def _get_sync_frequency(self, model):
        """Return the parameter sync frequency configured on the destination model."""
        return model.parameter_sync_frequency

    def set_parameter_sync(self, src_model, tgt_model):
        """Register a (src -> tgt) parameter sync pair; duplicates are ignored with a warning.

        Note: the duplicate check looks at ``parameter_sync_groups``, which is
        only populated by build_parameter_group().
        """
        group_name = self._get_group_name(src_model, tgt_model)
        if group_name in self.parameter_sync_groups:
            logger.warning(f"{group_name} already set, ignore")
            return
        sync_frequency = self._get_sync_frequency(tgt_model)
        assert sync_frequency >= 0, \
            f"parameter sync frequency from {src_model.name} to {tgt_model.name} " \
            f"expected to be greater than or equal to 0, while got {sync_frequency}."
        logger.info(f"sync parameters from {src_model.name} to {tgt_model.name} every {sync_frequency} episodes.")
        self._parameter_sync_model_pair.append((src_model, tgt_model))

    def warmup_collective_topology(self):
        """Call ``warmup_groups`` on every built parameter sync group."""
        for sync_group in self.parameter_sync_groups.values():
            sync_group.warmup_groups()

    def sync_parameters(self, episode_offset=0, requires_grad=None, validate=False, dryrun=False):
        """Run each sync group whose frequency divides ``episode_offset``.

        A group with a falsy frequency (0/None) is never run. Around each sync,
        both models are onloaded (without grad buffers, main weights or
        optimizer states) and offloaded again afterwards.

        Per the original contract: if requires_grad is False, all parameters
        are synchronized; this happens when parameters are broadcast at the
        beginning of training, so that inference matches training.
        """
        onload_kwargs = {
            "to_build_grad_buffers": False,
            "to_onload_main_weights": False,
            "to_onload_optimizer_states": False,
        }
        for sync_group in self.parameter_sync_groups.values():
            if not sync_group.frequency or episode_offset % sync_group.frequency != 0:
                continue
            # src_model, dst_model type: DistModel
            src_model, dst_model = sync_group.src_model, sync_group.dst_model
            future.wait(src_model.onload(**onload_kwargs))
            future.wait(dst_model.onload(**onload_kwargs))

            sync_group.sync(requires_grad, validate, dryrun=dryrun)

            future.wait(src_model.offload())
            future.wait(dst_model.offload())

    def set_func_decorator(self, model):
        """Decorate the methods of the model's class, at most once per model name.

        * every ``call_funcs`` entry is wrapped with ``preprocess_compute``
          (passing whether it is in ``trainable_funcs``);
        * ``save_checkpoint``, ``model_setup`` and ``call_funcs`` get ``timeit``;
        * the listed public entry points plus ``call_funcs`` get ``monitor_error``.
        """
        if is_decorated(model.name):
            return
        call_funcs = model.call_funcs
        model_cls = model.__class__

        for func_name in call_funcs:
            trainable = func_name in model.trainable_funcs
            decorate_class_func(model_cls, func_name, preprocess_compute, trainable)

        for func_name in ["save_checkpoint", "model_setup"] + call_funcs:
            decorate_class_func(model_cls, func_name, timeit, func_name)

        # public user function
        # TODO: use decorator to annotate
        for func_name in ["save_checkpoint", "model_setup", "onload", "offload", "build_dataset",
                          "_build_dataloader", "generate_vllm", "init"] + call_funcs:
            decorate_class_func(model_cls, func_name, monitor_error, func_name)
        set_decorated(model.name)

    def _to_dist_model(self, model):
        """Wrap one local model in a DistModel with ``model.num_replica`` actor replicas.

        Actors are not placed on devices here; see place_models_to_remote_devices.

        Args:
            model: BaseModule
        """
        self.set_func_decorator(model)
        model.finalize()

        # VLLMModuleV2 is checked first; it may also be a TorchModule.
        if isinstance(model, VLLMModuleV2):
            actor_cls = DistVLLMActor
        elif isinstance(model, TorchModule):
            actor_cls = DistTorchActor
        else:
            actor_cls = DistActor

        dist_model = DistModel()
        for replica_id in range(model.num_replica):
            dist_actor = actor_cls(model, self.resouce_manager.gpu_per_node, self.error_signal, self._port_manager,
                                   replica_id, self._storage)
            dist_model.add_replica(dist_actor)
        return dist_model

    def _find_param_recv_models(self, models):
        """Return the models in ``models`` that receive parameters from another model in ``models``.

        Each model appears at most once, in input order. Fewer than two models
        cannot sync with each other, so an empty list is returned.
        """
        if len(models) < 2:
            return []
        model_names = {model.name for model in models}
        return [
            model for model in models
            if any(src.name in model_names and tgt.name == model.name
                   for src, tgt in self._parameter_sync_model_pair)
        ]

    def find_model_packing_strategy(self, models, total_gpu):
        """Pack models into groups that each fit in ``total_gpu`` devices.

        Models are sorted by ``(trainable, total_gpu)`` descending. Each model
        is then put (best-fit) into the open pack with the smallest free
        capacity that still holds it; if none fits, a new pack is opened.
        Every model ends up in exactly one pack.
        e.g., given models A:8, B:4, C:4, total_gpu: 8
        then the pack strategy is [(A), (B,C)]

        Returns:
            list of lists of models; [] for empty input.

        Raises:
            AssertionError: if any model needs more than ``total_gpu`` devices.
        """
        if not models:
            return []
        assert max(m.total_gpu for m in models) <= total_gpu, \
            f"a model requires more than {total_gpu} gpus"
        sorted_models = sorted(models, key=lambda x: (x.trainable, x.total_gpu), reverse=True)
        final_packs = []
        # key: gpus still free in an open pack; value: open packs with that much room.
        # Invariant: a key is present only while its list is non-empty.
        unfinished_packs = defaultdict(list)

        for model in sorted_models:
            gpu = model.total_gpu
            if gpu == total_gpu:
                final_packs.append([model])
                continue
            if gpu in unfinished_packs:
                # exact fit: this model completes an open pack
                packs = unfinished_packs[gpu].pop(0)
                if not unfinished_packs[gpu]:
                    unfinished_packs.pop(gpu)
                packs.append(model)
                final_packs.append(packs)
                continue
            near_gpus = [d for d in unfinished_packs if d > gpu]
            if near_gpus:
                near_gpu = min(near_gpus)
                packs = unfinished_packs[near_gpu].pop(0)
                # Drop the bucket we popped from (near_gpu, not gpu). A stale empty
                # list would later be chosen as a candidate and pop(0) would fail.
                if not unfinished_packs[near_gpu]:
                    unfinished_packs.pop(near_gpu)
                packs.append(model)
                # update the remaining gpu number
                unfinished_packs[near_gpu - gpu].append(packs)
            else:
                # add model and wait for packing
                unfinished_packs[total_gpu - gpu].append([model])
        for packs_list in unfinished_packs.values():
            final_packs.extend(packs_list)
        return final_packs

    def place_gpu_models(self, gpu_models, env_list=None):
        """Place a set of GPU DistModels on one shared placement group.

        Args:
            gpu_models: List[DistModel], each with ``total_gpu > 0``.
            env_list: if given, ``(model, reverse_gpu_placement)`` tuples are
                appended for the caller to apply later; if None, dist env is
                set on each replica immediately.
        """
        if not gpu_models:
            return
        max_gpu = max(m.total_gpu for m in gpu_models)

        # create placement groups
        placement_group = self.resouce_manager.create_placement_group(max_gpu)
        for i, _ in enumerate(placement_group.bundle_specs):
            self.placement_groups.append((placement_group, i))
        models_str = ','.join([model.name for model in gpu_models])
        logger.info(f"create placement_group {placement_group.bundle_specs} for model {models_str} done")
        for model in gpu_models:
            # TODO: for colocate gpu_per_process > 1, support later
            assert model.gpu_per_process == 1
        self.model_packs = self.find_model_packing_strategy(gpu_models, max_gpu)

        # Each pack spans the device range on its own, so models in *other*
        # packs share devices with this model: those are its colocate models.
        for model in gpu_models:
            pack = next((p for p in self.model_packs if model in p), [])
            colocate_models = [m for m in gpu_models if m is not model and m not in pack]
            model.set_colocate_models(colocate_models)

        def _get_model_replica_from_pack(gpu_index, model_pack):
            """Return the replica that owns ``gpu_index`` in ``model_pack``, or None."""
            # for gpu rank between N * model.num_gpu_per_replica to (N + 1) *  model.num_gpu_per_replica
            # this function will return the same replica
            gpu_offset = 0
            for pack_model in model_pack:
                if gpu_index < gpu_offset + pack_model.total_gpu:
                    model_rank = gpu_index - gpu_offset
                    replica_id = model_rank // pack_model.num_gpu_per_replica
                    return pack_model.replicas[replica_id]
                gpu_offset += pack_model.total_gpu
            return None

        # 1. we list the models to place on each device
        # 2. for device i, the number of models is N, then the num_gpus for each ray actor is 1.0/N
        # replica here is DistActor
        gpu_to_replicas = []
        for i in range(max_gpu):
            replicas_on_gpu = []
            for model_pack in self.model_packs:
                replica = _get_model_replica_from_pack(i, model_pack)
                if replica is not None:
                    replicas_on_gpu.append(replica)
            gpu_to_replicas.append(replicas_on_gpu)

        # For each gpu rank, create actor for each replica
        for i, replicas in enumerate(gpu_to_replicas):
            group = i // self.resouce_manager.gpu_per_node
            for replica in replicas:
                num_gpus = 1.0 / len(replicas)
                if isinstance(replica.model, VLLMModuleV2) and replica.vllm_engine is None:
                    # the engine actor and the model actor split this replica's GPU share
                    num_gpus = num_gpus / 2
                    replica.create_engine_actor(num_gpus, placement_group, group)
                    # we do not want to add engine actor to all_actors
                    replica.all_actors.pop()
                replica.create_actor(num_gpus, placement_group, group)

        models_to_revert = self._find_param_recv_models(gpu_models)
        for model in gpu_models:
            # Reverse the placement of tgt models, so that shared models not in the same GPU
            # NCCL limit: NCCL WARN Duplicate GPU detected : rank 1 and rank 0 both on CUDA device
            # TODO: One GPU task still not work
            reverse_gpu_placement = model in models_to_revert
            if env_list is None:
                for replica in model.replicas:
                    replica.set_dist_env(reverse_gpu_placement)
            else:
                env_list.append((model, reverse_gpu_placement))

    def place_cpu_models(self, cpu_models):
        """Create actors for CPU-only DistModels.

        When GPU placement-group bundles already exist, replicas are spread
        round-robin over them. Otherwise a new placement group is created,
        sized with one ``cpu_per_process`` entry per replica, and used instead.
        """
        if not cpu_models:
            return
        if self.placement_groups:
            placement_groups = self.placement_groups
        else:
            num_cpus = [model.module_args.cpu_per_process
                        for model in cpu_models
                        for _ in range(model.module_args.num_replica)]
            placement_group = self.resouce_manager.create_placement_group(
                num_gpus=0, num_cpus=num_cpus, strategy=self.runtime_args.cpu_schedule_strategy)
            models_str = ','.join([model.name for model in cpu_models])
            logger.info(f"create placement_group {placement_group.bundle_specs} for model {models_str} done")
            placement_groups = [(placement_group, i) for i, _ in enumerate(placement_group.bundle_specs)]

        i = 0
        for cpu_model in cpu_models:
            for replica in cpu_model.replicas:
                pg, index = placement_groups[i]
                replica.create_actor(0, pg, index)
                i = (i + 1) % len(placement_groups)

    def place_models_to_remote_devices(self, models, env_list=None):
        """Place GPU and CPU models, then run ``preprocess_actors`` on every replica."""
        cpu_models = [model for model in models if model.total_gpu == 0]
        gpu_models = [model for model in models if model.total_gpu > 0]
        self.place_gpu_models(gpu_models, env_list)
        self.place_cpu_models(cpu_models)

        # DistActor.preprocess_actors will add remote call for each function in Actor
        for model in models:
            for replica in model.replicas:
                replica.preprocess_actors()

    def _set_dist_env(self, model, reverse):
        """Call ``set_dist_env(reverse)`` on every replica of ``model``."""
        for replica in model.replicas:
            replica.set_dist_env(reverse)

    def set_dist_env_concurrent(self, env_list):
        """Apply ``_set_dist_env`` to each ``(model, reverse)`` pair, one thread per pair.

        Raises:
            RuntimeError: wrapping the first failure observed; the executor still
                waits for the remaining tasks before the error propagates.
        """
        if not env_list:
            return
        with ThreadPoolExecutor(max_workers=len(env_list)) as executor:
            futures = [executor.submit(self._set_dist_env, model, reverse)
                       for model, reverse in env_list]
            for _future in concurrent.futures.as_completed(futures):
                try:
                    _future.result()
                except Exception as e:
                    raise RuntimeError(f"Set dist env generated an exception: {e}") from e

    def clean(self):
        """Destroy collective groups, kill every actor, and remove placement groups.

        Killing individual model actors is best-effort: failures are logged and
        skipped so the remaining actors are still cleaned up.
        """
        for group in self.parameter_sync_groups.values():
            group.destroy_collective_group()
        for dist_model in self._name2distmodel.values():
            for dist_actor in dist_model.replicas:
                for actor in dist_actor.all_actors:
                    try:
                        ray.kill(actor)
                    except Exception:
                        logger.info("Encountering exceptions in cleaning actors, but ok")
        ray.kill(self._storage)
        ray.kill(self.error_signal)
        self.resouce_manager.remove_placement_groups()
