# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/aws-neuron/neuronx-distributed-inference/blob/9993358ce052fd7a1bb4a7497a6318aac36ed95c/src/neuronx_distributed_inference/models/config.py
from typing import List, Optional, Union

import torch

from ....configuration_utils import NeuronConfig, register_neuron_config
from ....utils import map_torch_dtype


# File name used for a serialized neuron configuration (JSON).
NEURON_CONFIG_FILE = "neuron_config.json"


def to_dict(obj):
    """Recursively convert ``obj`` into plain, JSON-friendly Python values.

    Conversion rules, applied in order:

    - ``torch.dtype`` values become their short name
      (e.g. ``torch.bfloat16`` -> ``"bfloat16"``);
    - ``str``/``int``/``float`` values (including subclasses) are returned as-is;
    - ``dict`` and its subclasses become a plain ``dict`` with converted values;
    - ``list`` becomes a ``list`` and ``tuple`` becomes a ``tuple``, with
      converted items;
    - any other object exposing a ``__dict__`` becomes a ``dict`` of its
      instance attributes, converted recursively;
    - anything else is returned unchanged.

    Args:
        obj: The value to convert.

    Returns:
        The converted value.
    """
    # Check dtypes first so the result never depends on whether torch.dtype
    # instances happen to expose a ``__dict__``.
    if isinstance(obj, torch.dtype):
        return str(obj).split(".")[1]
    # Scalars are leaves. Subclasses of str/int (e.g. str/int enums) may carry
    # an instance ``__dict__`` and would otherwise lose their actual value.
    if isinstance(obj, (str, int, float)):
        return obj
    # isinstance (not an exact type check) so dict subclasses such as
    # OrderedDict keep their items instead of being serialized through their
    # (empty) instance ``__dict__``.
    if isinstance(obj, dict):
        return {k: to_dict(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [to_dict(v) for v in obj]
    if isinstance(obj, tuple):
        return tuple(to_dict(v) for v in obj)
    if hasattr(obj, "__dict__"):
        return {k: to_dict(v) for k, v in vars(obj).items()}
    return obj


class IncompatibleConfigError(ValueError):
    """Raised when configuration options are requested together but cannot be combined.

    Subclasses ``ValueError`` so callers already catching invalid-argument
    errors also catch it.
    """


@register_neuron_config
class NxDNeuronConfig(NeuronConfig):
    """
    Base config class for inference in NxD.

    This class contains attributes that are needed for various inference
    optimization/features in NxD.

    These attributes are a subset of the attributes in the original NxDI config class.
    Several flags kept from NxDI are not supported here and raise a `ValueError`
    when enabled (see `__init__`).
    """

    def __init__(
        self,
        checkpoint_id: Optional[str] = None,
        checkpoint_revision: Optional[str] = None,
        batch_size: Optional[int] = 1,
        max_batch_size: Optional[int] = None,
        continuous_batching: Optional[bool] = False,
        speculation_length: Optional[int] = 0,
        sequence_length: Optional[int] = 128,
        tp_degree: Optional[int] = 1,
        ep_degree: Optional[int] = 1,
        pp_degree: Optional[int] = 1,
        torch_dtype: Optional[Union[str, torch.dtype]] = torch.bfloat16,
        rpl_reduce_dtype: Optional[Union[str, torch.dtype]] = None,
        n_active_tokens: Optional[int] = None,
        max_context_length: Optional[int] = None,
        output_logits: Optional[bool] = False,
        padding_side: Optional[str] = "right",
        fused_qkv: Optional[bool] = False,
        vocab_parallel: Optional[bool] = False,
        sequence_parallel_enabled: Optional[bool] = False,
        is_chunked_prefill: Optional[bool] = False,
        flash_decoding_enabled: Optional[bool] = False,
        async_mode: Optional[bool] = False,
        qk_layernorm: Optional[bool] = False,
        attn_kernel_enabled: Optional[bool] = False,
        qkv_kernel_enabled: Optional[bool] = False,
        mlp_kernel_enabled: Optional[bool] = False,
        mlp_kernel_fuse_residual_add: Optional[bool] = False,
        enable_bucketing: Optional[bool] = False,
        target: Optional[str] = None,  # Set to "trn2" for trn2
        logical_nc_config: Optional[int] = 1,
        cc_pipeline_tiling_factor: Optional[int] = 2,
        num_cores_per_group: Optional[int] = 1,
        on_device_sampling: Optional[bool] = False,
        max_topk: Optional[int] = 256,
        start_rank_id: Optional[int] = 0,
        local_ranks_size: Optional[int] = None,
        capacity_factor: Optional[float] = None,
        glu_mlp: bool = True,
    ) -> None:
        """
        Validate the requested options and store them as attributes.

        Args:
            checkpoint_id: Id of the checkpoint to retrieve from the hub.
            checkpoint_revision: Revision of that checkpoint.
            batch_size: Batch size; also the default for `max_batch_size`.
            max_batch_size: Defaults to `batch_size` when None.
            continuous_batching: When True, `ctx_batch_size` is 1.
            speculation_length: Speculative decoding length; when > 0 it cannot be
                combined with `on_device_sampling` (or `async_mode`).
            sequence_length: Sequence length; default for `n_active_tokens` and
                `max_context_length`.
            tp_degree, ep_degree, pp_degree: Parallelism degrees; their product is
                `world_size`.
            torch_dtype: A `torch.dtype`, or a string converted with `map_torch_dtype`.
            rpl_reduce_dtype: Defaults to `torch_dtype`; strings are converted with
                `map_torch_dtype`.
            n_active_tokens: Defaults to `sequence_length`.
            max_context_length: Defaults to `sequence_length` (for vllm compatibility).
            output_logits, padding_side, fused_qkv, sequence_parallel_enabled,
            attn_kernel_enabled, mlp_kernel_fuse_residual_add, enable_bucketing,
            logical_nc_config, cc_pipeline_tiling_factor, num_cores_per_group,
            on_device_sampling, max_topk, start_rank_id, glu_mlp: Stored as-is.
            target: Target chip; per the inline note, set to "trn2" for trn2.
            local_ranks_size: Defaults to `world_size` when None.
            capacity_factor: MoE option, converted to `float` when not None.
            vocab_parallel, is_chunked_prefill, flash_decoding_enabled, async_mode,
            qk_layernorm, qkv_kernel_enabled, mlp_kernel_enabled: Not supported;
                must be left False.

        Raises:
            ValueError: If any unsupported flag listed above is enabled.
            IncompatibleConfigError: If `speculation_length > 0` together with
                `on_device_sampling`.
        """
        # TODO: these flags are supposed to work in NxDI. Either make them work or remove them
        if is_chunked_prefill:
            raise ValueError("`is_chunked_prefill` is not supported in optimum-neuron.")
        if flash_decoding_enabled:
            raise ValueError("`flash_decoding_enabled` is not supported in optimum-neuron.")
        if async_mode:
            raise ValueError("`async_mode` is not supported in optimum-neuron.")
        # NOTE(review): this check is unconditional (it does not look at `target`),
        # although the message only mentions trn1 chips — confirm intent.
        if qkv_kernel_enabled or mlp_kernel_enabled:
            raise ValueError("`qkv_kernel_enabled` and `mlp_kernel_enabled` are not supported for trn1 chips.")
        if vocab_parallel:
            raise ValueError("`vocab_parallel` is not supported in optimum-neuron.")
        if qk_layernorm:
            raise ValueError(
                "`qk_layernorm` is not supported in optimum-neuron. It is actually a modeling flag that affects the attention layer."
            )
        # Required to retrieve a checkpoint from the hub
        self.checkpoint_id = checkpoint_id
        self.checkpoint_revision = checkpoint_revision
        # Basic config for inference in NxD
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.tp_degree = tp_degree
        self.torch_dtype = torch_dtype
        # Accept dtype names (strings) as well as torch.dtype objects.
        if isinstance(self.torch_dtype, str):
            self.torch_dtype = map_torch_dtype(self.torch_dtype)
        self.n_active_tokens = self.sequence_length if n_active_tokens is None else n_active_tokens
        self.output_logits = output_logits

        self.padding_side = padding_side

        # Uses the raw `torch_dtype` argument; a string is converted just below,
        # so the result matches `self.torch_dtype` when no override is given.
        self.rpl_reduce_dtype = torch_dtype if rpl_reduce_dtype is None else rpl_reduce_dtype
        if isinstance(self.rpl_reduce_dtype, str):
            self.rpl_reduce_dtype = map_torch_dtype(self.rpl_reduce_dtype)

        # fallback to sequence_length is for compatibility with vllm
        self.max_context_length = max_context_length
        if self.max_context_length is None:
            self.max_context_length = sequence_length

        # Graph transforms
        self.fused_qkv = fused_qkv

        # Functional parallelism
        self.vocab_parallel = vocab_parallel
        self.sequence_parallel_enabled = sequence_parallel_enabled
        self.is_chunked_prefill = is_chunked_prefill

        # Continuous batching
        # TODO: Check if we really need different batch size for CTE and TKG, given
        # that we anyway provide two different config instance for them.
        self.continuous_batching = continuous_batching
        self.max_batch_size = batch_size if max_batch_size is None else max_batch_size

        # On-device sampling
        self.on_device_sampling = on_device_sampling
        self.max_topk = max_topk

        # async
        self.async_mode = async_mode

        # Bucketing
        self.enable_bucketing = enable_bucketing

        # Speculative decoding
        self.speculation_length = speculation_length
        if self.speculation_length > 0:
            # NOTE(review): async_mode is already rejected above, so this branch
            # cannot trigger at the moment; kept for parity with NxDI.
            if self.async_mode:
                raise IncompatibleConfigError("Speculative Decoding is not yet supported with async.")
            if self.on_device_sampling:
                raise IncompatibleConfigError("Speculative decoding is incompatible with on-device sampling")

        # Distributed config
        self.pp_degree = pp_degree
        self.ep_degree = ep_degree

        # QK layer normalization
        self.qk_layernorm = qk_layernorm

        # Multi-node
        # TODO: Check if start_rank_id can be modified dynamically at runtime
        # Otherwise, we need multiple exports for different start_rank_id
        self.start_rank_id = start_rank_id
        self.local_ranks_size = local_ranks_size
        # `world_size` reads tp/pp/ep degrees, which must already be set above.
        if self.local_ranks_size is None:
            self.local_ranks_size = self.world_size

        # Flash decoding
        self.flash_decoding_enabled = flash_decoding_enabled
        self.num_cores_per_group = num_cores_per_group

        # Kernels
        self.attn_kernel_enabled = attn_kernel_enabled
        self.qkv_kernel_enabled = qkv_kernel_enabled
        self.mlp_kernel_enabled = mlp_kernel_enabled
        self.mlp_kernel_fuse_residual_add = mlp_kernel_fuse_residual_add

        # compiler flags
        self.logical_nc_config = logical_nc_config
        self.cc_pipeline_tiling_factor = cc_pipeline_tiling_factor
        self.target = target

        # MoE specific
        self.capacity_factor = float(capacity_factor) if capacity_factor is not None else None
        self.glu_mlp = glu_mlp

    @property
    def ctx_batch_size(self) -> int:
        """Batch size for context encoding (CTE): 1 with continuous batching, else `batch_size`."""
        return 1 if self.continuous_batching else self.batch_size

    @property
    def tkg_batch_size(self) -> int:
        """Batch size for token generation (TKG): always `batch_size`."""
        return self.batch_size

    @property
    def world_size(self) -> int:
        """
        The total number of ranks in the distributed setup.

        Computed as `tp_degree * pp_degree * ep_degree`.
        """
        return self.tp_degree * self.pp_degree * self.ep_degree

    @property
    def weights_to_skip_layout_optimization(self) -> List[str]:
        """
        List of weights to skip layout optimization.

        Empty by default. Can be overridden by subclasses to specify weights that should not be optimized.
        """
        return []
