import os
from typing import Optional, Tuple

import torch
import torch.nn as nn

from text_generation_server.utils.weights import Weights
from text_generation_server.layers.fp8 import (
    Fp8Weight,
    fp8_quantize,
    quant_dtype,
    normalize_e4m3fn_to_native_float8,
    dynamic_quant,
    dequant_block_fp8_weight_naive,
)
from text_generation_server.layers.moe.fused_moe import select_experts
import habana_frameworks.torch as htorch


class FP8SparseMoELayer(nn.Module):
    """Sparse mixture-of-experts layer with FP8 expert weights, run on HPU.

    Per-expert gate/up and down projection weights are loaded and stacked
    along a leading expert dimension. When the ``USE_EXPERT_PARALLEL``
    environment variable is ``"true"`` (the default), each rank loads only a
    contiguous slice of the experts. Otherwise every rank loads all experts
    through the sharded (column/row) weight loaders. The forward pass routes
    tokens with ``select_experts`` and runs the fused
    ``torch.ops.hpu.mixture_of_experts`` kernel over this rank's experts.
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        scoring_func: Optional[str] = "softmax",
        e_score_correction_bias: Optional[float] = None,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
    ):
        """Load this rank's expert weights, re-quantizing block-scaled FP8.

        Args:
            n_expert_group: number of expert groups for grouped top-k routing,
                or ``None`` to disable grouped routing.
            n_experts: total number of experts in the model.
            prefix: weight-name prefix. Expert ``i``'s tensors are looked up
                as ``f"{prefix}.{i}.<proj_name>"``.
            renormalize: forwarded to ``select_experts``.
            topk: number of experts selected per token.
            topk_group: forwarded to ``select_experts``. It must be ``None``
                exactly when ``n_expert_group`` is ``None``.
            weights: weight loader. It also provides the process group used
                for expert parallelism.
            scoring_func: router scoring function forwarded to
                ``select_experts``.
            e_score_correction_bias: forwarded to ``select_experts``.
            gate_proj_name: per-expert gate projection name.
            up_proj_name: per-expert up projection name.
            down_proj_name: per-expert down projection name.

        Raises:
            ValueError: if expert parallelism leaves this rank without any
                expert (more ranks than expert slices).
        """
        super().__init__()

        assert (n_expert_group is None) == (
            topk_group is None
        ), "n_expert_group and topk_group must both be None or have some value"

        self.n_expert_group = n_expert_group
        self.topk = topk
        self.topk_group = topk_group
        self.renormalize = renormalize
        self.weight_block_size = weights.weights_loader.weight_block_size
        self.scoring_func = scoring_func
        self.e_score_correction_bias = e_score_correction_bias
        self.world_size = weights.process_group.size()
        self.rank = weights.process_group.rank()
        self.ep_rank = self.rank
        self.use_ep = os.getenv("USE_EXPERT_PARALLEL", "true").lower() == "true"

        if self.use_ep:
            # Experts are split into contiguous slices of
            # ceil(n_experts / world_size). When n_experts is not a multiple
            # of world_size, the trailing rank(s) own fewer experts. Clamp so
            # we never try to load expert ids >= n_experts.
            experts_per_rank = (n_experts + self.world_size - 1) // self.world_size
            self.ep_offset = self.ep_rank * experts_per_rank
            n_local_experts = min(experts_per_rank, n_experts - self.ep_offset)
            if n_local_experts <= 0:
                raise ValueError(
                    f"Expert parallelism assigns no experts to rank {self.ep_rank}: "
                    f"{n_experts} experts cannot be split across "
                    f"{self.world_size} ranks"
                )
        else:
            self.ep_offset = 0
            n_local_experts = n_experts
        # Number of experts whose weights are stored on this rank. Local
        # expert j corresponds to global expert j + self.ep_offset.
        self.n_local_experts = n_local_experts

        (
            self.gate_up_proj,
            self.gate_up_proj_weight_scale,
            self.gate_up_proj_input_scale,
        ) = _load_expert_multi_weights_col(
            prefix=prefix,
            n_experts=n_local_experts,
            gate_proj_name=gate_proj_name,
            up_proj_name=up_proj_name,
            weights=weights,
            use_ep=self.use_ep,
            ep_offset=self.ep_offset,
        )

        self.down_proj, self.down_proj_weight_scale, self.down_proj_input_scale = (
            _load_expert_weights_row(
                prefix=prefix,
                n_experts=n_local_experts,
                name=down_proj_name,
                weights=weights,
                use_ep=self.use_ep,
                ep_offset=self.ep_offset,
            )
        )
        if self.weight_block_size is not None:
            # Block-scaled FP8 checkpoint: dequantize with the block scales,
            # then re-quantize with dynamic_quant so the kernel gets scales in
            # the form it consumes. The trailing singleton scale dim is dropped.
            self.gate_up_proj, self.gate_up_proj_weight_scale = dynamic_quant(
                dequant_block_fp8_weight_naive(
                    self.gate_up_proj,
                    self.gate_up_proj_weight_scale,
                    self.weight_block_size,
                )
            )
            self.down_proj, self.down_proj_weight_scale = dynamic_quant(
                dequant_block_fp8_weight_naive(
                    self.down_proj, self.down_proj_weight_scale, self.weight_block_size
                )
            )
            self.gate_up_proj_weight_scale, self.down_proj_weight_scale = (
                self.gate_up_proj_weight_scale.squeeze(-1),
                self.down_proj_weight_scale.squeeze(-1),
            )

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        """Route ``x`` through the experts stored on this rank.

        Args:
            x: hidden states. They are FP8-quantized with a single scale
                before entering the kernel.
            gating_output: router logits passed to ``select_experts``.

        Returns:
            The result of ``torch.ops.hpu.mixture_of_experts``, called with
            ``experts_min``/``experts_max`` set to the inclusive range of
            global expert ids owned by this rank.
        """
        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=gating_output,
            use_grouped_topk=self.n_expert_group is not None,
            top_k=self.topk,
            renormalize=self.renormalize,
            topk_group=self.topk_group,
            num_expert_group=self.n_expert_group,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
        )
        x_fp8, x_scale = dynamic_quant(x, single_scale=True)

        # The fused HPU kernel takes per-expert lists of weights and scales
        # rather than the stacked tensors held on the module.
        local_experts = range(self.n_local_experts)
        w13_list = [self.gate_up_proj[j, ...] for j in local_experts]
        w2_list = [self.down_proj[j, ...] for j in local_experts]
        w13_weight_scale = [
            self.gate_up_proj_weight_scale[j, ...] for j in local_experts
        ]
        w2_weight_scale = [self.down_proj_weight_scale[j, ...] for j in local_experts]

        final_hidden_states = torch.ops.hpu.mixture_of_experts(
            hidden_states=x_fp8,
            expert_routing_table=topk_ids.to(torch.int64),
            router_weights=topk_weights.to(x.dtype),
            w12=w13_list,
            w3=w2_list,
            d_scale_hidden_states=x_scale,
            d_scale_w12=w13_weight_scale,
            d_scale_w3=w2_weight_scale,
            permuted_weights=True,
            activation="silu",
            # Inclusive global expert-id range of the weights passed above.
            experts_min=self.ep_offset,
            experts_max=self.ep_offset + self.n_local_experts - 1,
        )
        htorch.core.mark_step()
        return final_hidden_states


def _load_expert_weights(
    get_weight_fn,
    *,
    prefix: str,
    n_experts: int,
    name: Optional[str],
    weights: Weights,
    ep_offset: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Load ``n_experts`` consecutive experts and stack them as FP8.

    Args:
        get_weight_fn: callable ``(prefix, expert_id, name, weights)`` that
            returns the ``Fp8Weight`` for a single expert.
        prefix: weight-name prefix passed through to ``get_weight_fn``.
        n_experts: number of experts to load.
        name: projection name passed through to ``get_weight_fn``. It may be
            ``None`` when ``get_weight_fn`` ignores it.
        weights: weight loader passed through to ``get_weight_fn``.
        ep_offset: id of the first expert to load. Row ``i`` of the result
            holds expert ``i + ep_offset``.

    Returns:
        A tuple ``(weight, weight_scale, max_input_scale)``:

        - ``weight``: shape ``(n_experts, *per_expert_weight_shape)``, in
          ``quant_dtype``.
        - ``weight_scale``: float32, shape
          ``(n_experts, *per_expert_scale_shape)``.
        - ``max_input_scale``: the largest input scale among experts whose
          weights were already FP8, or ``None`` if none had one.

    Raises:
        ValueError: if ``n_experts`` is not positive.
        TypeError: if ``get_weight_fn`` returns something other than an
            ``Fp8Weight``.
    """
    # Explicit check instead of an ``assert``: under ``python -O`` the assert
    # is stripped and the function would silently return ``(None, None, None)``.
    if n_experts <= 0:
        raise ValueError(f"n_experts must be positive, got {n_experts}")

    all_weight = None
    all_weight_scales = None
    max_input_scale = None

    for i in range(n_experts):
        weight = get_weight_fn(prefix, i + ep_offset, name, weights)

        if not isinstance(weight, Fp8Weight):
            raise TypeError(
                f"expected Fp8Weight for expert {i + ep_offset}, "
                f"got {type(weight).__name__}"
            )

        if all_weight is None:
            # Shapes of the stacked outputs come from the first expert.
            all_weight = torch.empty(
                (n_experts,) + weight.weight.shape,
                dtype=quant_dtype,
                device=weight.weight.device,
            )
            all_weight_scales = torch.empty(
                (n_experts,) + weight.weight_scale.shape,
                dtype=torch.float32,
                device=weight.weight.device,
            )

        if weight.weight.dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz}:
            # Already-FP8 checkpoint: convert to the native FP8 format and keep
            # the largest per-expert input scale.
            all_weight[i], all_weight_scales[i], current_input_scale = (
                normalize_e4m3fn_to_native_float8(
                    weight.weight, weight.weight_scale, weight.input_scale
                )
            )
            if current_input_scale is not None:
                if max_input_scale is None or current_input_scale > max_input_scale:
                    max_input_scale = current_input_scale
        else:
            # Higher-precision checkpoint: quantize on load with a scalar scale.
            all_weight[i], all_weight_scales[i] = fp8_quantize(
                weight.weight, scalar=True
            )

    return all_weight, all_weight_scales, max_input_scale


def _load_expert_multi_weights_col(
    *,
    prefix: str,
    n_experts: int,
    gate_proj_name: str,
    up_proj_name: str,
    weights: Weights,
    use_ep: bool = False,
    ep_offset: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Load the fused gate/up projections of ``n_experts`` experts.

    For each expert, ``{prefix}.{i}.{gate_proj_name}`` and
    ``{prefix}.{i}.{up_proj_name}`` are loaded together along dim 0.

    Args:
        prefix: weight-name prefix of the experts.
        n_experts: number of experts to load.
        gate_proj_name: gate projection sub-module name.
        up_proj_name: up projection sub-module name.
        weights: weight loader.
        use_ep: if true, load full (unsharded) tensors via
            ``get_multi_weights`` starting at expert ``ep_offset``. If false,
            load column-sharded tensors via ``get_multi_weights_col`` starting
            at expert 0.
        ep_offset: first expert id to load; only used when ``use_ep`` is true.

    Returns:
        ``(weight, weight_scale, max_input_scale)``, as produced by
        ``_load_expert_weights``.
    """

    def get_weight_fn_sharded(prefix, i, name, weights):
        # Tensor-parallel path: each rank gets a column shard of every expert.
        return weights.get_multi_weights_col(
            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
        )

    def get_weight_fn(prefix, i, name, weights):
        # Expert-parallel path: each rank gets whole experts.
        return weights.get_multi_weights(
            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
        )

    return _load_expert_weights(
        get_weight_fn if use_ep else get_weight_fn_sharded,
        prefix=prefix,
        n_experts=n_experts,
        name=None,
        weights=weights,
        # Without expert parallelism every rank holds all experts from id 0.
        ep_offset=ep_offset if use_ep else 0,
    )


def _load_expert_weights_row(
    *,
    prefix: str,
    n_experts: int,
    name: str,
    weights: Weights,
    use_ep: bool = False,
    ep_offset: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Load the ``{prefix}.{i}.{name}`` projection of ``n_experts`` experts.

    Args:
        prefix: weight-name prefix of the experts.
        n_experts: number of experts to load.
        name: projection sub-module name (e.g. the down projection).
        weights: weight loader.
        use_ep: if true, load full (unsharded) tensors via ``get_weights``
            starting at expert ``ep_offset``. If false, load row-sharded
            tensors via ``get_weights_row`` starting at expert 0.
        ep_offset: first expert id to load; only used when ``use_ep`` is true.

    Returns:
        ``(weight, weight_scale, max_input_scale)``, as produced by
        ``_load_expert_weights``.
    """

    def get_weight_fn_sharded(prefix, i, name, weights):
        # Tensor-parallel path: each rank gets a row shard of every expert.
        return weights.get_weights_row(f"{prefix}.{i}.{name}")

    def get_weight_fn(prefix, i, name, weights):
        # Expert-parallel path: each rank gets whole experts.
        return weights.get_weights(f"{prefix}.{i}.{name}")

    return _load_expert_weights(
        get_weight_fn if use_ep else get_weight_fn_sharded,
        prefix=prefix,
        n_experts=n_experts,
        name=name,
        weights=weights,
        # Without expert parallelism every rank holds all experts from id 0.
        ep_offset=ep_offset if use_ep else 0,
    )
