# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""module runtime decorator"""

import inspect
import traceback

import torch
from torch.cuda import nvtx
import ray

from chatlearn.models.vllm_module_v2 import VLLMModuleV2
from chatlearn.utils import future
from chatlearn.utils import utils
from chatlearn.utils.constant import CHATLEARN_REGROUP_TAG, INDEX_TAG
from chatlearn.utils.constant import LOG_START
from chatlearn.utils.global_vars import _EXIT_ACTOR_NAME, set_wrap_func
from chatlearn.utils.utils import execute
from chatlearn.utils.utils import regroup_by_concat_along_batch, slice_by_index_along_batch


def monitor_error(func, func_name):
    """Wrap `func` so that exceptions are reported to the exit actor and the
    error signal, and re-raised on the master node."""
    def inner(self, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except Exception as e:
            self._logger.exception(f"Catch exception ========= in {self.name} {func_name} {e}")
            exit_actor = ray.get_actor(_EXIT_ACTOR_NAME)
            traceback_msg = traceback.format_exc()
            address = self.get_address()
            ray.get(exit_actor.add_error_node_and_msg.remote(address, traceback_msg))
            future.wait(self.error_signal.set_address.remote(address))
            # every worker reports the error above; the master node additionally
            # logs the traceback, stops the Ray cluster, and re-raises
            if self.is_master_node():
                for line in traceback_msg.split("\n"):
                    self._logger.exception(line)
                execute("ray stop")
                raise

    return inner
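

# A minimal usage sketch (`MyModule` is hypothetical): wrapping a method by hand,
# which is roughly what `decorate_class_func` at the bottom of this file does,
# minus its already-decorated check and wrap-func bookkeeping:
#
#   MyModule.train_step = monitor_error(MyModule.train_step, "train_step")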


def timeit(func, func_name):
    """Wrap `func` to time it under `func_name` on the last rank and to drive
    nsys/torch profiler ranges when profiling is enabled."""
    def inner(self, *args, **kwargs):
        if self.runtime_args.nsys:
            nvtx.range_push(func_name)
        if self.is_last_rank():
            # for classes that inherit from a base class, this wrapper may be
            # entered multiple times, so keep the first start time
            if not self.timers(func_name).started_:
                self.timers(func_name).start()
            ret = func(self, *args, **kwargs)
            self.timers(func_name).stop()
        else:
            ret = func(self, *args, **kwargs)
        # step the torch profiler during iterations 1-2 and stop it at iteration 3
        # (replica 0, forward_step/train_step only)
        if self.profiler is not None and 0 < self._iteration <= 2 and self.replica_id == 0 \
                and func_name in ["forward_step", "train_step"]:
            self.profiler.step()
        if self.profiler is not None and self._iteration == 3 and self.replica_id == 0 \
                and func_name in ["forward_step", "train_step"]:
            self.profiler.stop()
        if self.runtime_args.nsys:
            nvtx.range_pop()
        return ret

    return inner
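

# A minimal usage sketch (`MyModule` is hypothetical): the timer is keyed by the
# function name, so repeated calls accumulate under a single entry:
#
#   MyModule.forward_step = timeit(MyModule.forward_step, "forward_step")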


def split_along_batch(batch, new_batch_size):
    """Split `batch` (a list/tuple/dict of equal-length sequences) into
    micro-batches of at most `new_batch_size` along the batch dimension."""
    assert isinstance(batch, (list, tuple, dict)), \
        f"batch type {type(batch)} is not supported"
    if isinstance(batch, (list, tuple)):
        bs = len(batch[0])
        keys = range(len(batch))
    else:
        bs = len(next(iter(batch.values())))
        keys = batch.keys()

    accum_bs = 0
    new_batches = []
    while accum_bs < bs:
        if isinstance(batch, (list, tuple)):
            new_batch = [batch[key][accum_bs:min(accum_bs + new_batch_size, bs)] for key in keys]
        else:
            new_batch = {key: batch[key][accum_bs:min(accum_bs + new_batch_size, bs)] for key in keys}
        accum_bs += new_batch_size
        new_batches.append(new_batch)
    return new_batches
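

# Example (illustrative, using `torch` as imported above): a dict batch of
# size 5 split with new_batch_size=2 yields micro-batches of sizes 2, 2 and 1:
#
#   >>> split_along_batch({"x": torch.arange(5)}, 2)
#   [{'x': tensor([0, 1])}, {'x': tensor([2, 3])}, {'x': tensor([4])}]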

def concat_along_batch(tensors):
    """Concatenate a list of per-micro-batch dicts into a single dict along the
    batch dimension; returns an empty dict if there is no output."""
    batched = {}
    if tensors[0] is None:
        return batched

    for key in tensors[0].keys():
        to_batch = [results[key] for results in tensors]
        if isinstance(to_batch[0], torch.Tensor):
            batched[key] = torch.concat(to_batch)
        elif isinstance(to_batch[0], list):
            batched[key] = []
            for seq in to_batch:
                batched[key].extend(seq)
        else:
            raise TypeError(f"unsupported type {type(to_batch[0])} for key {key} in concat_along_batch")

    return batched
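

# Example (illustrative): the inverse of `split_along_batch` for dict outputs,
# concatenating tensors and extending lists along the batch dimension:
#
#   >>> concat_along_batch([{"x": torch.tensor([0, 1])}, {"x": torch.tensor([2, 3])}])
#   {'x': tensor([0, 1, 2, 3])}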


def preprocess_compute(func, trainable):
    """
    1. If not trainable, merge a list of dicts into one dict, i.e., merge the inputs of forward_step.
    2. Split a list of data for data parallelism; this is used for train_step.
    3. Move outputs to CPU.
    """

    def inner(self, *args, **kwargs):
        args = future.get(args)
        assert isinstance(args, (list, tuple)), f"expected args to be a list or tuple, got {type(args)}, args: {args}."
        batched_data_list = [None] * len(args)
        if not trainable:
            self._logger.info(f"{LOG_START} start to merge data for {self.name} replica {self.replica_id}.")
            self._logger.info(f"{LOG_START} preprocess_compute model {self.name} replica {self.replica_id} \
                has inputs from {len(args)} input node.")

            for idx, arg_obj in enumerate(args):
                batched_data_list[idx] = arg_obj
                if CHATLEARN_REGROUP_TAG in arg_obj:
                    batched_data_list[idx] = regroup_by_concat_along_batch(arg_obj[CHATLEARN_REGROUP_TAG])
                if INDEX_TAG in arg_obj:
                    batched_data_list[idx] = slice_by_index_along_batch(batched_data_list[idx], arg_obj[INDEX_TAG])
                assert isinstance(batched_data_list[idx], dict), \
                    f"expected arg for {self.name} to be a dict, got {type(batched_data_list[idx])}, arg: {batched_data_list[idx]}"
            if all(isinstance(batched_data, dict) for batched_data in batched_data_list):
                merged = {}
                for batched_data in batched_data_list:
                    merged.update(batched_data)
                args = [merged]

            self._logger.info(f"{LOG_START} complete to merge data for {self.name}.")

        def get_kwarg(key):
            return kwargs.pop(key, False)
        to_empty_cache = get_kwarg('to_empty_cache')
        to_onload = get_kwarg('to_onload')
        to_offload = get_kwarg('to_offload')
        is_last_batch = get_kwarg('is_last_batch')
        is_eval = get_kwarg('is_eval')

        if to_onload:
            if isinstance(self, VLLMModuleV2):
                self.onload_for_workers()
            else:
                self.onload()
        generation_batch_size = self.module_args.generation_batch_size
        final_results = None
        if not trainable and generation_batch_size:
            # split into micro-batches when generation_batch_size < input batch size, then concat the results;
            # this happens when different models have different batch sizes
            input_batch = 0
            if len(args) > 0:
                for value in args[0].values():
                    input_batch = len(value)
                    break
                input_data = args[0]
            else:
                input_data = None
            if generation_batch_size != -1 and input_data is not None \
                    and input_batch > generation_batch_size and not hasattr(self, 'generate_vllm'):
                args = list(args)
                batches = split_along_batch(input_data, generation_batch_size)
                results = []
                for batch in batches:
                    args[0] = batch
                    if 'iteration' in inspect.signature(func).parameters:
                        kwargs["iteration"] = self._iteration
                    ret = func(self, *args, **kwargs)
                    self._iteration += 1
                    ret = utils.to_device('cpu', ret)
                    results.append(ret)
                # for model with DP/EP, we need to return results from all ranks
                # for model with TP/PP, only return the results from last rank
                if self.is_last_rank() or self.data_parallel_size is None or self.data_parallel_size > 1 \
                        or isinstance(self, VLLMModuleV2):
                    final_results = concat_along_batch(results)
            else:
                if 'iteration' in inspect.signature(func).parameters:
                    kwargs["iteration"] = self._iteration
                ret = func(self, *args, **kwargs)
                ret = utils.to_device('cpu', ret)
                self._iteration += 1
                final_results = None
                # for model with DP/EP, we need to return results from all ranks
                # for model with TP/PP, only return the results from last rank
                if self.is_last_rank() or self.data_parallel_size is None or self.data_parallel_size > 1 \
                        or isinstance(self, VLLMModuleV2):
                    final_results = ret
        else:
            if 'iteration' in inspect.signature(func).parameters:
                kwargs["iteration"] = self._train_iteration
            self._train_iteration += 1
            ret = func(self, *args, **kwargs)
            ret = utils.to_device('cpu', ret)
            if self.is_last_rank():
                final_results = ret
        if to_empty_cache:
            if isinstance(self, VLLMModuleV2):
                self.empty_cuda_graph_for_workers()
                self.empty_cache_for_workers()
            else:
                self.empty_cache()
        if to_offload:
            if isinstance(self, VLLMModuleV2):
                self.offload_for_workers()
            else:
                self.offload()
        if is_last_batch and not is_eval:
            self.runtime_args.consumed_samples += self.runtime_args.sample_per_episode
        return final_results

    return inner
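

# A hedged sketch of the micro-batching path above (names illustrative): with
# generation_batch_size=2 and an input batch of 5 samples, a non-trainable
# module's `func` runs three times on micro-batches of sizes 2, 2 and 1, and
# the per-micro-batch outputs are merged back with `concat_along_batch`:
#
#   batches = split_along_batch({"x": torch.arange(5)}, 2)   # 3 micro-batches
#   results = [{"y": b["x"] * 2} for b in batches]           # stand-in for func
#   merged = concat_along_batch(results)                     # {"y": tensor([0, 2, 4, 6, 8])}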


def decorate_class_func(cls, func_name, decorator, *args, **kwargs):
    """Apply `decorator` to `cls.func_name` in place, skipping methods that are
    already decorated (e.g., inherited from an already-decorated base class)."""
    if not hasattr(cls, func_name):
        return
    func = getattr(cls, func_name)
    if func.__qualname__.startswith(decorator.__name__):
        # already decorated
        # This usually occurs when one class inherits from another; for example,
        # if 'reference' inherits from 'policy', methods like 'offload_optimizer_states'
        # are already decorated in the base class and must not be decorated again.
        return
    new_func = decorator(func, *args, **kwargs)
    set_wrap_func(func, new_func)
    setattr(cls, func_name, new_func)
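

# A minimal sketch of how these decorators are typically stacked on a module
# class (the concrete call sites live in the ChatLearn runtime; `MyModule` and
# the chosen method are illustrative). Each call wraps the result of the
# previous one:
#
#   decorate_class_func(MyModule, "forward_step", preprocess_compute, trainable=False)
#   decorate_class_func(MyModule, "forward_step", timeit, "forward_step")
#   decorate_class_func(MyModule, "forward_step", monitor_error, "forward_step")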
