# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FSDP-based TorchModule implementation for HuggingFace causal LMs."""
import os
import random
import functools
import gc

import ray
import numpy as np
import torch
import torch.distributed as dist
from torch import optim, nn
from torch.distributed.fsdp import (MixedPrecision, ShardingStrategy, ShardedOptimStateDictConfig, ShardedStateDictConfig,
                                    FullStateDictConfig, StateDictType, FullyShardedDataParallel as FSDP)
from torch.distributed.fsdp.wrap import (size_based_auto_wrap_policy, transformer_auto_wrap_policy, lambda_auto_wrap_policy,
                                        _or_policy)
from torch.distributed.fsdp._runtime_utils import _lazy_init
from torch.multiprocessing.reductions import reduce_tensor

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.trainer_pt_utils import get_module_class_from_name

from chatlearn.utils.logger import debug_rank_0, log_rank_0
from chatlearn.utils.utils import dict_to_simplenamespace

from .torch_module import TorchModule


class FSDPModule(TorchModule):
    """TorchModule subclass that runs a HuggingFace causal LM under PyTorch FSDP.

    It sets up the process group and device mesh, builds and FSDP-wraps the model
    (plus an AdamW optimizer when trainable), moves parameters / optimizer states
    between GPU and CPU, exports weights as CUDA IPC handles by submodule name, and
    saves/loads per-rank sharded checkpoints (optionally also a HuggingFace copy).

    Args
    ----
    name : str
        model name
    """

    def __init__(self, *args, **kwargs):

        super().__init__(*args, **kwargs)

        if not self.trainable:
            # Inference only: the micro batch size always follows generation_batch_size.
            if (
                self.model_args.get("train_micro_batch_size")
                != self.module_args.generation_batch_size
            ):
                self._logger.info(
                    f"{self.name} Overwrite train_micro_batch_size with generation_batch_size {self.module_args.generation_batch_size}"
                )
            self.train_micro_batch_size = self.module_args.generation_batch_size
        else:
            self.train_micro_batch_size = self.runtime_args.train_micro_batch_size
            self.train_global_batch_size = self.runtime_args.train_global_batch_size

        # Number of ranks each parameter is sharded across (see create_device_mesh).
        self.fsdp_size = self.module_args.fsdp_size
        # Created lazily by create_device_mesh, which setup_distributed calls.
        self.device_mesh = None

    def get_visible_gpus(self):
        """
        Return the GPU ids Ray assigned to this actor.

        :meta private:
        """
        return ray.get_gpu_ids()

    @staticmethod
    def init_fn(x: torch.nn.Module):
        """Materialization hook passed to FSDP as ``param_init_fn``.

        On every rank except 0, the module's own (non-recursive) parameters and buffers
        are replaced with uninitialized storage on the current CUDA device. Rank 0 keeps
        its values; ``sync_module_states=True`` is set in model_setup so they are
        broadcast from rank 0.

        NOTE(review): FSDP only invokes ``param_init_fn`` for modules that need
        materialization (e.g. meta-device parameters). create_model loads real weights
        on every rank, so this may be a no-op here; confirm.
        """
        if torch.distributed.get_rank() != 0:
            x = x.to_empty(device=torch.cuda.current_device(), recurse=False)
            torch.cuda.empty_cache()
        return x

    @staticmethod
    def get_fsdp_wrap_policy(module: torch.nn.Module, min_num_params: int = 0):
        """Get FSDP wrap policy for the module.

        The returned policy is the OR of:

        * a size-based policy when ``min_num_params > 0``; otherwise a transformer-layer
          policy over the classes named in ``module._no_split_modules`` (if present);
        * a lambda policy that wraps ``module.model.embed_tokens``, ``module.model.norm``
          and ``module.lm_head`` as separate FSDP units. This lets their weights be
          gathered per submodule (see ``get_weight_ipc_handles_by_name``). The layout is
          hard-coded for Qwen2.5-style models, and submodules that do not exist are
          skipped.

        Args:
            module: The module to get wrap policy for
            min_num_params: size based wrap policy min num params

        Returns:
            A callable usable as FSDP ``auto_wrap_policy``, or ``None`` when no policy
            applies.
        """
        fsdp_transformer_layer_cls_to_wrap = getattr(module, "_no_split_modules", None)
        policies = []

        # Size-based and transformer-layer wrapping are alternatives: the transformer
        # policy is only used when min_num_params is 0.
        if min_num_params > 0:
            policies.append(
                functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params)
            )
        elif fsdp_transformer_layer_cls_to_wrap is not None:
            transformer_cls_to_wrap = set()
            for layer_class in fsdp_transformer_layer_cls_to_wrap:
                transformer_cls = get_module_class_from_name(module, layer_class)
                assert transformer_cls is not None, "Could not find the transformer layer class to wrap in the model."
                transformer_cls_to_wrap.add(transformer_cls)
            policies.append(
                functools.partial(
                    transformer_auto_wrap_policy,
                    transformer_layer_cls=transformer_cls_to_wrap,
                )
            )

        # HACK: hard-coded for Qwen2.5. The embedding, final norm and LM head each become
        # their own FSDP unit so each one's state dict can be fetched by name. Attributes
        # are looked up with getattr so that other layouts skip this policy instead of
        # raising AttributeError.
        inner_model = getattr(module, "model", None)
        standalone_submodules = [
            getattr(inner_model, "embed_tokens", None),
            getattr(inner_model, "norm", None),
            getattr(module, "lm_head", None),
        ]
        standalone_ids = {id(sub) for sub in standalone_submodules if sub is not None}

        def lambda_fn(sub_module: nn.Module):
            # Identity match (nn.Module defines no __eq__).
            return id(sub_module) in standalone_ids

        if standalone_ids:
            policies.append(functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_fn))

        if not policies:
            return None
        return functools.partial(_or_policy, policies=policies)

    def create_device_mesh(self, world_size, fsdp_size):
        """Create ``self.device_mesh`` once; later calls are no-ops.

        * ``fsdp_size == world_size``: a 1-D mesh named ``("fsdp",)``; every rank shards
          parameters together.
        * otherwise: a 2-D mesh of shape ``(world_size // fsdp_size, fsdp_size)`` named
          ``("ddp", "fsdp")``. The outer dim replicates and the inner dim shards, which
          is the layout FSDP's HYBRID_SHARD expects.

        Args:
            world_size: total number of ranks.
            fsdp_size: ranks per sharding group. ``None`` or a non-positive value means
                ``world_size``.

        Raises:
            ValueError: if ``world_size`` is not divisible by ``fsdp_size``.
        """
        if self.device_mesh is not None:
            return
        if fsdp_size is None or fsdp_size <= 0:
            fsdp_size = world_size
        if world_size % fsdp_size != 0:
            raise ValueError(
                f"world_size ({world_size}) must be divisible by fsdp_size ({fsdp_size})"
            )
        if world_size == fsdp_size:
            self.device_mesh = dist.device_mesh.init_device_mesh(
                "cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]
            )
        else:
            self.device_mesh = dist.device_mesh.init_device_mesh(
                "cuda",
                mesh_shape=(world_size // fsdp_size, fsdp_size),
                mesh_dim_names=["ddp", "fsdp"],
            )
        self._logger.info(f"world size {world_size}, fsdp_size {fsdp_size}, {self.device_mesh}")

    def setup_distributed(self):
        """Initialize the NCCL process group (if not done yet) and the FSDP device mesh."""
        self._logger.info(f"distributed env: {self.get_dist_env()}")
        if not dist.is_initialized():
            dist.init_process_group(backend="nccl")

        self.create_device_mesh(self.world_size, self.fsdp_size)

    def peak_memory(self):
        """
        Return the running maximum of allocated CUDA memory in GiB.

        The running maximum is kept in ``self._peak_memory``.

        :meta private:
        """
        self._peak_memory = max(
            self._peak_memory, torch.cuda.max_memory_allocated() / (1024**3)
        )
        return self._peak_memory

    def empty_cache(self):
        """
        Run Python GC, release cached CUDA memory and reset CUDA peak-memory stats.

        The work is timed under the "empty_cache" timer. Peak memory before and after
        is logged at debug level on rank 0.

        :meta private:
        """
        if not self.timers("empty_cache").started_:
            self.timers("empty_cache").start()
        peak_mem = torch.cuda.max_memory_allocated() / (1024**3)
        debug_rank_0(
            f"{self.name} replica: {self.replica_id}, before empty cache, peak mem: {peak_mem:.2f} GiB",
            self._logger,
        )
        # Manual gc so unreachable tensors are freed before emptying the CUDA cache.
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        peak_mem = torch.cuda.max_memory_allocated() / (1024**3)
        debug_rank_0(
            f"{self.name} replica: {self.replica_id}, after empty cache, peak mem: {peak_mem:.2f} GiB",
            self._logger,
        )
        self.timers("empty_cache").stop()

    def check_param_exists(self, names):
        """
        check if the given names exists in current model

        Logs the missing names on rank 0 and returns False if any are missing,
        otherwise returns True.

        :meta private:
        """
        not_exists = [name for name in names if not self.exist_parameter(name)]
        if not_exists:
            log_rank_0(
                f"parameters not exists: {not_exists} in model {self.name}",
                self._logger,
            )
            return False
        return True

    def create_model(self, model_path, torch_dtype):
        """Load a HuggingFace causal LM from ``model_path`` with FlashAttention-2."""
        model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path=model_path,
            torch_dtype=torch_dtype,
            attn_implementation="flash_attention_2",
            trust_remote_code=True,
        )
        return model

    @property
    def data_parallel_size(self):
        """
        Size of the default process group; every rank is a data-parallel rank.

        :meta private:
        """
        return dist.get_world_size()

    @property
    def data_parallel_rank(self):
        """
        Rank in the default process group.

        :meta private:
        """
        return dist.get_rank()

    def tensor_parallel_rank(self):
        # NOTE(review): reports the data-parallel rank, presumably so callers can tell
        # FSDP shards apart; confirm against the caller.
        return self.data_parallel_rank

    def pipeline_parallel_rank(self):
        # NOTE(review): returns 1 (not 0) for the single pipeline stage; presumably
        # relied upon by callers, so left unchanged.
        return 1

    def expert_model_parallel_size(self):
        return 1

    def model_setup(self):
        """
        Build the FSDP-wrapped model, the tokenizer and (if trainable) the optimizer.

        The steps are:

        1. Distributed setup.
        2. Load the HF model in bf16 and enable non-reentrant gradient checkpointing.
        3. Wrap the model with FSDP (bf16 compute, fp32 reduce/buffers) and cast the
           wrapped parameters to fp32.
        4. Set the default state-dict type to SHARDED_STATE_DICT.
        5. Optionally resume from a checkpoint, then offload.

        :meta private:
        """
        super().model_setup()
        self.setup_distributed()
        args = dict_to_simplenamespace(self.model_args)
        self.args = args
        model = self.create_model(args.pretrain_or_model, torch_dtype=torch.bfloat16)
        self.tokenizer = AutoTokenizer.from_pretrained(
            args.pretrain_or_model, trust_remote_code=True, use_fast=True
        )
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})
        mix_precision_config = MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
            buffer_dtype=torch.float32,
        )
        # A 2-D ("ddp", "fsdp") mesh needs HYBRID_SHARD: shard within the inner "fsdp"
        # dim and all-reduce gradients across the outer "ddp" dim. FULL_SHARD takes a
        # single process group from the mesh and would never synchronize the replicas.
        if self.device_mesh is not None and self.device_mesh.ndim == 2:
            sharding_strategy = ShardingStrategy.HYBRID_SHARD
        else:
            sharding_strategy = ShardingStrategy.FULL_SHARD
        auto_wrap_policy = self.get_fsdp_wrap_policy(module=model)
        self.model = FSDP(
            model,
            cpu_offload=None,
            auto_wrap_policy=auto_wrap_policy,
            device_id=torch.cuda.current_device(),
            sharding_strategy=sharding_strategy,
            mixed_precision=mix_precision_config,
            sync_module_states=True,
            param_init_fn=FSDPModule.init_fn,
            device_mesh=self.device_mesh,
            forward_prefetch=False,
        )
        # Keep fp32 master weights; MixedPrecision casts them to bf16 for compute.
        self.model.to(torch.float32)
        FSDP.set_state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT)
        if not self.trainable:
            self.optimizer = None
            self.model.eval()
        else:
            self.optimizer = optim.AdamW(
                self.model.parameters(),
                lr=self.module_args.args_dict.get("learning_rate", 2e-6),
                betas=(0.9, 0.999),
                weight_decay=1e-2
            )

        # resume model weights
        if self.resume_training:
            self.load_checkpoint(self._episode_id)

        self.offload()

    def get_fsdp_param_name(self):
        """Return the parameter names of ``self.model`` without FSDP wrapper components.

        The ``_fsdp_wrapped_module`` and ``_flat_param`` path components are removed
        from each name reported by ``named_parameters()``.
        """
        fsdp_parts = {"_fsdp_wrapped_module", "_flat_param"}
        return [
            ".".join(part for part in name.split(".") if part not in fsdp_parts)
            for name, _ in self.model.named_parameters()
        ]

    def get_weight_ipc_handles_by_name(self, block_name: str):
        """
        Export the weights of one FSDP-wrapped submodule as CUDA IPC handles.

        The target is the FSDP unit whose module path (with ``_fsdp_wrapped_module.``
        removed) equals ``block_name``. Only that unit's state dict is gathered, which
        avoids materializing the whole model's state dict.

        Args:
            block_name: cleaned module path of an FSDP unit.

        Returns:
            A dict mapping ``"<block_name>.<param name>"`` to the
            ``torch.multiprocessing.reductions.reduce_tensor`` result for the full
            (unsharded) tensor, or ``None`` if no FSDP unit matches ``block_name``.
        """
        torch.cuda.empty_cache()
        for prefix_name, module in self.model.named_modules():
            prefix_name = prefix_name.replace('_fsdp_wrapped_module.', '')
            if isinstance(module, FSDP) and prefix_name == block_name:
                # The state-dict type is SHARDED_STATE_DICT (set in model_setup), so
                # entries are sharded; full_tensor() gathers each one across ranks.
                state_dict = module.state_dict()
                return {
                    '.'.join([prefix_name, name]): reduce_tensor(param.full_tensor())
                    for name, param in state_dict.items()
                }
        return None

    @torch.no_grad()
    def onload_weights(self, empty_cache=True):
        """Move every FSDP flat parameter back to the current CUDA device.

        Flat params handled by FSDP's own CPU offload (``handle._offload_params``) are
        skipped. Must be called on the root FSDP module.
        """
        _lazy_init(self.model, self.model)
        assert self.model._is_root, "Only support root model onloading to GPU"
        device_id = torch.cuda.current_device()
        for handle in self.model._all_handles:
            if handle._offload_params:
                continue
            flat_param = handle.flat_param
            handle.flat_param_to(torch.device(f"cuda:{device_id}"), non_blocking=True)
            # the following still keeps id(._local_shard) != id(.data)
            flat_param._local_shard = flat_param.data
        if empty_cache:
            torch.cuda.empty_cache()

    @torch.no_grad()
    def offload_weights(self, empty_cache=True):
        """Move every FSDP flat parameter (local shard) to CPU.

        For each handle, any unsharded flat param is also freed. Flat params handled by
        FSDP's own CPU offload are skipped. CUDA IPC memory is collected afterwards.
        """
        assert isinstance(self.model, FSDP)
        # lazy init FSDP model
        _lazy_init(self.model, self.model)
        assert self.model._is_root, "Only support root model offloading to CPU"
        for handle in self.model._all_handles:
            if handle._offload_params:
                continue
            flat_param = handle.flat_param
            # Sanity check: .data and ._local_shard are distinct tensor objects that
            # share the same storage and size.
            assert (
                flat_param.data.data_ptr() == flat_param._local_shard.data_ptr()
                and id(flat_param.data) != id(flat_param._local_shard)
                and flat_param.data.size() == flat_param._local_shard.size()
            )
            handle.flat_param_to(torch.device("cpu"), non_blocking=True)
            # Explicit call to free unshard flat param
            handle._free_unsharded_flat_param()
            # the following still keeps id(._local_shard) != id(.data)
            flat_param._local_shard = flat_param.data
            assert id(flat_param._local_shard) != id(flat_param.data)

        # Explicitly release IPC handles
        torch.cuda.ipc_collect()
        if empty_cache:
            torch.cuda.empty_cache()

    def _move_optimizer_states(self, device):
        """Move every tensor value in the optimizer state to ``device`` in place."""
        # Iterate the state dict directly. Indexing the defaultdict by param would
        # insert empty entries for params that have no state yet.
        for state in self.optimizer.state.values():
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to(device, non_blocking=True)

    @torch.no_grad()
    def offload_optimizer_states(self, empty_cache=True):
        """Move optimizer state tensors to CPU.

        This is a no-op when there is no optimizer (inference only) or no state yet.
        """
        if self.optimizer is None or not self.optimizer.state:
            return
        self._move_optimizer_states("cpu")

        if empty_cache:
            torch.cuda.empty_cache()

    @torch.no_grad()
    def onload_optimizer_states(self, empty_cache=True):
        """Move optimizer state tensors to the current CUDA device.

        This is a no-op when there is no optimizer (inference only) or no state yet.
        """
        if self.optimizer is None or not self.optimizer.state:
            return
        self._move_optimizer_states(torch.cuda.current_device())

        if empty_cache:
            torch.cuda.empty_cache()

    def save_checkpoint(self, iteration):
        """Save this rank's model, optimizer and RNG state for ``iteration``.

        Files are written to ``{output_dir}/save_model/{name}/{iteration}`` and named by
        world size and rank. When the ``save_hf`` model arg is set (default True), rank 0
        also writes a HuggingFace-format copy to ``huggingface/`` under that directory.
        Contains collectives (barrier, full state-dict gather), so every rank must call
        it.
        """
        save_dir = f"{self.runtime_args.output_dir}/save_model/{self.name}/{iteration}"
        # Every rank creates the directory idempotently. A rank-0-only mkdir races with
        # the other ranks' torch.save below, and is invisible to other nodes when the
        # output dir is on node-local storage.
        os.makedirs(save_dir, exist_ok=True)

        state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True)
        optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True)
        with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):
            model_state_dict = self.model.state_dict()
            optimizer_state_dict = self.optimizer.state_dict() if self.optimizer is not None else None
            extra_state_dict = {
                "rng": self.get_rng_state(),
            }
            model_path = os.path.join(save_dir, f"model_world_size_{self.data_parallel_size}_rank_{self.data_parallel_rank}.pt")
            optim_path = os.path.join(save_dir, f"optim_world_size_{self.data_parallel_size}_rank_{self.data_parallel_rank}.pt")
            extra_path = os.path.join(save_dir, f"extra_state_world_size_{self.data_parallel_size}_rank_{self.data_parallel_rank}.pt")
            torch.save(model_state_dict, model_path)
            torch.save(optimizer_state_dict, optim_path)
            torch.save(extra_state_dict, extra_path)
        torch.distributed.barrier()

        # save for hf format
        if self.model_args.get("save_hf", True):
            state_dict_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
            with FSDP.state_dict_type(self.model, StateDictType.FULL_STATE_DICT, state_dict_cfg, None):
                # Collective: every rank participates, but only rank 0 receives the
                # full state dict (rank0_only=True).
                model_state_dict = self.model.state_dict()
                if self.data_parallel_rank == 0:
                    hf_path = os.path.join(save_dir, "huggingface")
                    os.makedirs(hf_path, exist_ok=True)
                    model_config = self.model._fsdp_wrapped_module.config
                    model_config.save_pretrained(hf_path)
                    self.tokenizer.save_pretrained(hf_path)

                    # Build the skeleton on the meta device, then allocate uninitialized
                    # CPU storage; the actual weights come from model_state_dict.
                    with torch.device("meta"):
                        save_model = AutoModelForCausalLM.from_config(model_config, torch_dtype=torch.bfloat16)
                    save_model.to_empty(device="cpu")
                    save_model.save_pretrained(hf_path, state_dict=model_state_dict)
        self._logger.info(f"save checkpoint to {save_dir}")

    def load_checkpoint(self, iteration):
        """Restore this rank's model, optimizer and RNG state saved by save_checkpoint.

        Skips (with a log message) when the checkpoint directory does not exist. The
        per-rank files must have been written with the same world size. Ends with a
        barrier, so every rank must call it.
        """
        load_dir = f"{self.runtime_args.output_dir}/save_model/{self.name}/{iteration}"
        if not os.path.exists(load_dir):
            self._logger.info(f"{load_dir} not exists, will skip load")
            return
        model_path = os.path.join(load_dir, f"model_world_size_{self.data_parallel_size}_rank_{self.data_parallel_rank}.pt")
        optim_path = os.path.join(load_dir, f"optim_world_size_{self.data_parallel_size}_rank_{self.data_parallel_rank}.pt")
        extra_state_path = os.path.join(load_dir, f"extra_state_world_size_{self.data_parallel_size}_rank_{self.data_parallel_rank}.pt")

        # NOTE: weights_only=False unpickles arbitrary objects; only load checkpoints
        # this job (or a trusted one) wrote. Loading to CPU avoids a GPU memory spike;
        # load_state_dict copies/casts the values onto the parameters' devices.
        model_state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
        optimizer_state_dict = (
            torch.load(optim_path, map_location="cpu", weights_only=False)
            if self.optimizer is not None
            else None
        )
        extra_state_dict = torch.load(extra_state_path, map_location="cpu", weights_only=False)

        state_dict_cfg = ShardedStateDictConfig(offload_to_cpu=True)
        optim_cfg = ShardedOptimStateDictConfig(offload_to_cpu=True)
        with FSDP.state_dict_type(self.model, StateDictType.SHARDED_STATE_DICT, state_dict_cfg, optim_cfg):
            self.model.load_state_dict(model_state_dict)
            if self.optimizer is not None:
                self.optimizer.load_state_dict(optimizer_state_dict)
        # recover random state
        if "rng" in extra_state_dict:
            # 'rng' may not exist for backward compatibility
            self.load_rng_state(extra_state_dict["rng"])
        torch.distributed.barrier()

    @staticmethod
    def get_rng_state():
        """Snapshot the CPU torch, CUDA torch, NumPy and Python ``random`` RNG states."""
        rng_state = {
            "cpu": torch.get_rng_state(),
            "cuda": torch.cuda.get_rng_state(),
            "numpy": np.random.get_state(),
            "random": random.getstate(),
        }
        return rng_state

    @staticmethod
    def load_rng_state(rng_state):
        """Restore RNG states produced by ``get_rng_state``."""
        torch.set_rng_state(rng_state["cpu"])
        torch.cuda.set_rng_state(rng_state["cuda"])
        np.random.set_state(rng_state["numpy"])
        random.setstate(rng_state["random"])
    