# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from functools import partial

import torch

from megatron.core.tensor_parallel import gather_from_sequence_parallel_region
from megatron.core.transformer.moe.moe_utils import (
    sequence_load_balancing_loss_func,
    switch_load_balancing_loss_func,
)

from megatron.core.transformer.moe.router import TopKRouter as _TopKRuter

from .moe_utils import topk_softmax_with_capacity

class TopKRouter(_TopKRuter):
    """Route each token to the top-k experts."""

    def compute_routing_scores_for_aux_loss(self, logits: torch.Tensor) -> torch.Tensor:
        """Compute routing scores based on the score function.

        Args:
            logits (torch.Tensor): The logits tensor after gating, shape: [num_tokens, num_experts].

        Returns:
            torch.Tensor: The normalized routing scores.
        """
        if self.score_function == "softmax":
            scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
            # NOTE: hard code norm_topk_prob here
            scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
        elif self.score_function == "sigmoid":
            scores = torch.sigmoid(logits)
            scores = (
                scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if self.topk > 1 else scores
            )
        else:
            raise ValueError(f"Invalid score_function: {self.score_function}")
        return scores

    def aux_loss_load_balancing(self, logits: torch.Tensor):
        """Apply auxiliary loss-based load balancing to the logits tensor.

        Args:
            logits (torch.Tensor): The logits tensor after gating, shape: [num_tokens, num_experts].

        Returns:
            probs (torch.Tensor): The probabilities of token to experts assignment.
            routing_map (torch.Tensor): The mask of token to experts assignment.
        """
        probs, routing_map, tokens_per_expert = topk_softmax_with_capacity(
            logits,
            self.topk,
            capacity_factor=self.config.moe_expert_capacity_factor,
            pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
            drop_policy=self.config.moe_token_drop_policy,
            use_pre_softmax=self.config.moe_router_pre_softmax,
            num_groups=self.config.moe_router_num_groups,
            group_topk=self.config.moe_router_group_topk,
            scaling_factor=self.config.moe_router_topk_scaling_factor,
            deterministic_mode=self.config.deterministic_mode,
            score_function=self.score_function,
            expert_bias=self.expert_bias,
        )

        if self.training and torch.is_grad_enabled():
            # Apply auxiliary load balancing loss
            # Skip auxiliary loss calculations when using torch.no_grad() or checkpointing.
            scores = self.compute_routing_scores_for_aux_loss(logits)
            aux_loss_func = partial(
                switch_load_balancing_loss_func,
                probs=scores,
                tokens_per_expert=tokens_per_expert,
                topk=self.topk,
            )
            probs = self.apply_load_balancing_loss(
                activation=probs, load_balancing_loss_func=aux_loss_func
            )
        return probs, routing_map

    def seq_aux_loss_load_balancing(self, logits: torch.Tensor, bsz: int, seq_length: int):
        """Apply sequence-auxiliary loss-based load balancing to the logits tensor.

        Args:
            logits (torch.Tensor): The logits tensor after gating, shape: [num_tokens, num_experts].
            bsz (int): The batch size.
            seq_length (int): The sequence length.

        Returns:
            probs (torch.Tensor): The probabilities of token to experts assignment.
            routing_map (torch.Tensor): The mask of token to experts assignment.
        """

        probs, routing_map, tokens_per_expert = topk_softmax_with_capacity(
            logits,
            self.topk,
            capacity_factor=self.config.moe_expert_capacity_factor,
            pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
            drop_policy=self.config.moe_token_drop_policy,
            use_pre_softmax=self.config.moe_router_pre_softmax,
            num_groups=self.config.moe_router_num_groups,
            group_topk=self.config.moe_router_group_topk,
            scaling_factor=self.config.moe_router_topk_scaling_factor,
            deterministic_mode=self.config.deterministic_mode,
            score_function=self.score_function,
            expert_bias=self.expert_bias,
        )

        if self.training and torch.is_grad_enabled():
            # Apply sequence-auxiliary load balancing loss
            scores = self.compute_routing_scores_for_aux_loss(logits)
            aux_loss_func = partial(
                sequence_load_balancing_loss_func,
                probs=scores,
                routing_map=routing_map,
                batch_size=bsz,
                seq_length=seq_length,
                topk=self.topk,
            )
            probs = self.apply_load_balancing_loss(
                activation=probs, load_balancing_loss_func=aux_loss_func
            )

        return probs, routing_map

    def routing(self, logits: torch.Tensor):
        """Top-k routing function

        Args:
            logits (torch.Tensor): Logits tensor after gating.

        Returns:
            probs (torch.Tensor): The probabilities of token to experts assignment.
            routing_map (torch.Tensor): The mapping of token to experts assignment,
                with shape [num_tokens, num_experts].
        """
        seq_length, bsz = logits.shape[:2]
        logits = logits.view(-1, self.config.num_moe_experts)

        # Apply Z-Loss
        logits = self.apply_z_loss(logits)

        if self.config.moe_token_dispatcher_type == "alltoall_seq":
            # Gather the logits from the TP region
            logits = gather_from_sequence_parallel_region(logits)

        if self.routing_type == "sinkhorn":
            scores, routing_map = self.sinkhorn_load_balancing(logits)
        elif self.routing_type == "aux_loss":
            scores, routing_map = self.aux_loss_load_balancing(logits)
        elif self.routing_type == "seq_aux_loss":
            scores, routing_map = self.seq_aux_loss_load_balancing(logits, bsz, seq_length)
        elif self.routing_type == "none":
            # A naive top-k routing without load balancing
            scores, routing_map, _ = topk_softmax_with_capacity(
                logits,
                self.topk,
                capacity_factor=self.config.moe_expert_capacity_factor,
                pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
                drop_policy=self.config.moe_token_drop_policy,
                use_pre_softmax=self.config.moe_router_pre_softmax,
                num_groups=self.config.moe_router_num_groups,
                group_topk=self.config.moe_router_group_topk,
                scaling_factor=self.config.moe_router_topk_scaling_factor,
                deterministic_mode=self.config.deterministic_mode,
                score_function=self.score_function,
                expert_bias=self.expert_bias,
            )
        else:
            raise ValueError(f"Unsupported MoE routing type: {self.routing_type}")
        # Prevent extra local tokens accumulation on evaluation or activation recomputation
        if self.enable_expert_bias and torch.is_grad_enabled():
            with torch.no_grad():
                self.local_tokens_per_expert += routing_map.sum(dim=0)

        return scores, routing_map

