backends/gaudi/server/text_generation_server/layers/moe/__init__.py [29:216]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
@runtime_checkable
class MoELayer(Protocol):
    """
    Structural interface shared by the MoE layer implementations.

    An implementation is constructed purely from keyword-only arguments that
    describe the routing (expert count, top-k, optional expert grouping) and
    where its weights live, and is applied through
    ``forward(x, gating_output=...)``.
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
        hidden_act: str = "silu",
        scoring_func: Optional[str] = None,
        e_score_correction_bias: Optional[float] = None,
    ):
        """Build the layer from its routing configuration and weight source."""

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        """Apply the experts to ``x``, routed according to ``gating_output``."""


class DenseMoELayer(nn.Module):
    """
    Layer for MoE that applies *all* experts to each token and then weights
    their outputs based on the calculated routing. This layer is much slower
    than `SparseMoELayer` and should only be used when no fused kernels are
    available (e.g. for unsupported quantizers).

    Every expert's gate/up/down projection is evaluated for every token; a
    dense (num_tokens, n_experts) routing matrix, zero outside each token's
    top-k experts, scales and sums the expert outputs.
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
        hidden_act: str = "silu",
        scoring_func: Optional[str] = None,
        e_score_correction_bias: Optional[float] = None,
    ):
        """
        Load the projections of all ``n_experts`` experts.

        Args:
            n_expert_group: Number of expert groups for grouped top-k routing,
                or ``None`` for plain top-k. Must be ``None`` exactly when
                ``topk_group`` is ``None``.
            n_experts: Number of experts to load.
            prefix: Weight prefix; expert ``i``'s projections are loaded from
                ``f"{prefix}.{i}.{proj_name}"``.
            renormalize: Forwarded to ``grouped_topk`` / ``fused_topk``.
            topk: Number of experts selected per token.
            topk_group: Forwarded to ``grouped_topk``, or ``None``.
            weights: Weight source for the projections; its ``process_group``
                is stored on the layer.
            gate_proj_name, up_proj_name, down_proj_name: Per-expert
                projection names under the expert prefix.
            hidden_act: Name of the activation applied to the gate projection.
            scoring_func: Not supported by this layer; must be ``None``.
            e_score_correction_bias: Not supported by this layer; must be
                ``None``.

        Raises:
            AssertionError: If an unsupported option is given, or if only one
                of ``n_expert_group`` / ``topk_group`` is set.
        """
        super().__init__()

        assert scoring_func is None, "scoring func is not handled"
        assert e_score_correction_bias is None, "scoring correction bias is not handled"

        log_once(
            logger.info,
            "No fused layers are available for this model type, using (slower) dense MoE layer",
        )

        assert (n_expert_group is None) == (
            topk_group is None
        ), "n_expert_group and topk_group must both be None or have some value"

        self.n_expert_group = n_expert_group
        self.n_experts = n_experts
        self.renormalize = renormalize
        self.topk = topk
        self.topk_group = topk_group

        if "gelu" in hidden_act:
            # Any name containing "gelu" uses torch's gelu directly; only the
            # two names below get the tanh approximation, all others exact GELU.
            # NOTE(review): names such as "gelu_new" / "quick_gelu" also land
            # here — confirm exact GELU is intended for them.
            self.act = lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh"
                    if hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
                    else "none"
                ),
            )
        elif "silu" in hidden_act:
            self.act = torch.nn.functional.silu
        else:
            self.act = ACT2FN[hidden_act]

        # NOTE(review): the expert projections are kept in plain Python lists,
        # so they are not registered as submodules of this nn.Module (they do
        # not appear in `.modules()` / `.parameters()`). Moving to
        # nn.ModuleList would change that — confirm nothing relies on it first.
        self.gate_proj = self._load_experts(
            TensorParallelColumnLinear, prefix, gate_proj_name, weights
        )
        self.up_proj = self._load_experts(
            TensorParallelColumnLinear, prefix, up_proj_name, weights
        )
        self.down_proj = self._load_experts(
            TensorParallelRowLinear, prefix, down_proj_name, weights
        )

        self.process_group = weights.process_group

    def _load_experts(self, linear_cls, prefix: str, proj_name: str, weights: Weights) -> list:
        """Load projection ``proj_name`` (bias-free) for every expert, in expert order."""
        return [
            linear_cls.load(
                None,
                prefix=f"{prefix}.{i}.{proj_name}",
                weights=weights,
                bias=False,
            )
            for i in range(self.n_experts)
        ]

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        """
        Route every token to its top-k experts and combine their outputs.

        x: (sequence_length, model_dim). Any extra leading dimensions are
            flattened into the token dimension and restored on the output.
        gating_output: (sequence_length, n_experts), flattened the same way
            as ``x`` before being passed to the top-k helper.

        Returns a tensor with the same shape as ``x``. The down projections
        are called with ``reduce=False``, so no tensor-parallel reduction
        happens inside this layer (presumably the caller performs it — confirm).
        """
        # Flatten leading dims so that routing operates per token.
        input_shape = x.shape
        x = x.reshape(-1, input_shape[-1])
        gating_output = gating_output.reshape(-1, gating_output.shape[-1])

        if self.n_expert_group is not None and self.topk_group is not None:
            topk_weights, topk_ids = grouped_topk(
                x,
                gating_output,
                self.topk,
                renormalize=self.renormalize,
                num_expert_group=self.n_expert_group,
                topk_group=self.topk_group,
            )
        else:
            topk_weights, topk_ids = fused_topk(
                x, gating_output, self.topk, self.renormalize
            )

        # Dense routing matrix: each token's top-k weights are scattered into
        # their expert columns; every other entry stays zero.
        routing_weights = torch.zeros(
            topk_ids.shape[0], self.n_experts, dtype=x.dtype, device=x.device
        )
        routing_weights.scatter_(
            1, topk_ids.long(), topk_weights.to(routing_weights.dtype)
        )

        out = torch.zeros_like(x)
        for i in range(self.n_experts):
            h = self.act(self.gate_proj[i](x)) * self.up_proj[i](x)
            h = self.down_proj[i](h, reduce=False)
            out += h * routing_weights[:, i].view(-1, 1)

        # Restore the caller's original leading dimensions (no-op for 2-D x).
        return out.reshape(input_shape)


class SparseMoELayer(nn.Module):
    """
    Layer for MoE that uses fused kernels to only apply the active experts
    for each token (rather than applying all experts and selecting the
    outputs of active experts).
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        scoring_func: Optional[str] = "softmax",
        e_score_correction_bias: Optional[float] = None,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
    ):
        super().__init__()
        if (
            isinstance(weights.loader, DefaultWeightsLoader)
            and isinstance(weights.loader.weight_class, UnquantizedWeight)
        ) or isinstance(weights.loader, HybridFP8UnquantLoader):
            if (
                isinstance(weights.loader, HybridFP8UnquantLoader)
                and weights.loader.to_fp8
            ):
                cls = FP8SparseMoELayer
            else:
                cls = UnquantizedSparseMoELayer
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



server/text_generation_server/layers/moe/__init__.py [44:231]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
@runtime_checkable
class MoELayer(Protocol):
    """
    Structural interface shared by the MoE layer implementations.

    An implementation is constructed purely from keyword-only arguments that
    describe the routing (expert count, top-k, optional expert grouping) and
    where its weights live, and is applied through
    ``forward(x, gating_output=...)``.
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
        hidden_act: str = "silu",
        scoring_func: Optional[str] = None,
        e_score_correction_bias: Optional[float] = None,
    ):
        """Build the layer from its routing configuration and weight source."""

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        """Apply the experts to ``x``, routed according to ``gating_output``."""


class DenseMoELayer(nn.Module):
    """
    Layer for MoE that applies *all* experts to each token and then weights
    their outputs based on the calculated routing. This layer is much slower
    than `SparseMoELayer` and should only be used when no fused kernels are
    available (e.g. for unsupported quantizers).

    Every expert's gate/up/down projection is evaluated for every token; a
    dense (num_tokens, n_experts) routing matrix, zero outside each token's
    top-k experts, scales and sums the expert outputs.
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
        hidden_act: str = "silu",
        scoring_func: Optional[str] = None,
        e_score_correction_bias: Optional[float] = None,
    ):
        """
        Load the projections of all ``n_experts`` experts.

        Args:
            n_expert_group: Number of expert groups for grouped top-k routing,
                or ``None`` for plain top-k. Must be ``None`` exactly when
                ``topk_group`` is ``None``.
            n_experts: Number of experts to load.
            prefix: Weight prefix; expert ``i``'s projections are loaded from
                ``f"{prefix}.{i}.{proj_name}"``.
            renormalize: Forwarded to ``grouped_topk`` / ``fused_topk``.
            topk: Number of experts selected per token.
            topk_group: Forwarded to ``grouped_topk``, or ``None``.
            weights: Weight source for the projections; its ``process_group``
                is stored on the layer.
            gate_proj_name, up_proj_name, down_proj_name: Per-expert
                projection names under the expert prefix.
            hidden_act: Name of the activation applied to the gate projection.
            scoring_func: Not supported by this layer; must be ``None``.
            e_score_correction_bias: Not supported by this layer; must be
                ``None``.

        Raises:
            AssertionError: If an unsupported option is given, or if only one
                of ``n_expert_group`` / ``topk_group`` is set.
        """
        super().__init__()

        assert scoring_func is None, "scoring func is not handled"
        assert e_score_correction_bias is None, "scoring correction bias is not handled"

        log_once(
            logger.info,
            "No fused layers are available for this model type, using (slower) dense MoE layer",
        )

        assert (n_expert_group is None) == (
            topk_group is None
        ), "n_expert_group and topk_group must both be None or have some value"

        self.n_expert_group = n_expert_group
        self.n_experts = n_experts
        self.renormalize = renormalize
        self.topk = topk
        self.topk_group = topk_group

        if "gelu" in hidden_act:
            # Any name containing "gelu" uses torch's gelu directly; only the
            # two names below get the tanh approximation, all others exact GELU.
            # NOTE(review): names such as "gelu_new" / "quick_gelu" also land
            # here — confirm exact GELU is intended for them.
            self.act = lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh"
                    if hidden_act in ["gelu_fast", "gelu_pytorch_tanh"]
                    else "none"
                ),
            )
        elif "silu" in hidden_act:
            self.act = torch.nn.functional.silu
        else:
            self.act = ACT2FN[hidden_act]

        # NOTE(review): the expert projections are kept in plain Python lists,
        # so they are not registered as submodules of this nn.Module (they do
        # not appear in `.modules()` / `.parameters()`). Moving to
        # nn.ModuleList would change that — confirm nothing relies on it first.
        self.gate_proj = self._load_experts(
            TensorParallelColumnLinear, prefix, gate_proj_name, weights
        )
        self.up_proj = self._load_experts(
            TensorParallelColumnLinear, prefix, up_proj_name, weights
        )
        self.down_proj = self._load_experts(
            TensorParallelRowLinear, prefix, down_proj_name, weights
        )

        self.process_group = weights.process_group

    def _load_experts(self, linear_cls, prefix: str, proj_name: str, weights: Weights) -> list:
        """Load projection ``proj_name`` (bias-free) for every expert, in expert order."""
        return [
            linear_cls.load(
                None,
                prefix=f"{prefix}.{i}.{proj_name}",
                weights=weights,
                bias=False,
            )
            for i in range(self.n_experts)
        ]

    def forward(self, x: torch.Tensor, *, gating_output: torch.Tensor) -> torch.Tensor:
        """
        Route every token to its top-k experts and combine their outputs.

        x: (sequence_length, model_dim). Any extra leading dimensions are
            flattened into the token dimension and restored on the output.
        gating_output: (sequence_length, n_experts), flattened the same way
            as ``x`` before being passed to the top-k helper.

        Returns a tensor with the same shape as ``x``. The down projections
        are called with ``reduce=False``, so no tensor-parallel reduction
        happens inside this layer (presumably the caller performs it — confirm).
        """
        # Flatten leading dims so that routing operates per token.
        input_shape = x.shape
        x = x.reshape(-1, input_shape[-1])
        gating_output = gating_output.reshape(-1, gating_output.shape[-1])

        if self.n_expert_group is not None and self.topk_group is not None:
            topk_weights, topk_ids = grouped_topk(
                x,
                gating_output,
                self.topk,
                renormalize=self.renormalize,
                num_expert_group=self.n_expert_group,
                topk_group=self.topk_group,
            )
        else:
            topk_weights, topk_ids = fused_topk(
                x, gating_output, self.topk, self.renormalize
            )

        # Dense routing matrix: each token's top-k weights are scattered into
        # their expert columns; every other entry stays zero.
        routing_weights = torch.zeros(
            topk_ids.shape[0], self.n_experts, dtype=x.dtype, device=x.device
        )
        routing_weights.scatter_(
            1, topk_ids.long(), topk_weights.to(routing_weights.dtype)
        )

        out = torch.zeros_like(x)
        for i in range(self.n_experts):
            h = self.act(self.gate_proj[i](x)) * self.up_proj[i](x)
            h = self.down_proj[i](h, reduce=False)
            out += h * routing_weights[:, i].view(-1, 1)

        # Restore the caller's original leading dimensions (no-op for 2-D x).
        return out.reshape(input_shape)


class SparseMoELayer(nn.Module):
    """
    Layer for MoE that uses fused kernels to only apply the active experts
    for each token (rather than applying all experts and selecting the
    outputs of active experts).
    """

    def __init__(
        self,
        *,
        n_expert_group: Optional[int],
        n_experts: int,
        prefix: str,
        renormalize: bool,
        topk: int,
        topk_group: Optional[int],
        weights: Weights,
        scoring_func: Optional[str] = "softmax",
        e_score_correction_bias: Optional[float] = None,
        gate_proj_name: str = "gate_proj",
        up_proj_name: str = "up_proj",
        down_proj_name: str = "down_proj",
    ):
        super().__init__()
        if (
            isinstance(weights.loader, DefaultWeightsLoader)
            and isinstance(weights.loader.weight_class, UnquantizedWeight)
        ) or isinstance(weights.loader, HybridFP8UnquantLoader):
            if (
                isinstance(weights.loader, HybridFP8UnquantLoader)
                and weights.loader.to_fp8
            ):
                cls = FP8SparseMoELayer
            else:
                cls = UnquantizedSparseMoELayer
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



