# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
from typing import Any, Optional

import torch
from torch import nn

from tzrec.ops import Kernel


class BaseModule(nn.Module, abc.ABC):
    """Common base class for TorchEasyRec modules.

    Holds two settings that can be pushed down to every sub-module that
    carries them: an inference-mode flag and a kernel type.

    Args:
        is_inference (bool): whether the module runs in inference mode.
        kernel (Optional[Kernel]): kernel type; when left as None,
            ``kernel()`` falls back to ``Kernel.TRITON``.
    """

    def __init__(
        self,
        is_inference: bool = False,
        kernel: Optional[Kernel] = None,
    ) -> None:
        super().__init__()
        self._is_inference = is_inference
        self._kernel = kernel

    def kernel(self) -> Kernel:
        """Return the configured kernel, or ``Kernel.TRITON`` if unset."""
        chosen = self._kernel
        return Kernel.TRITON if chosen is None else chosen

    # pyre-ignore [2]
    def recursive_setattr(self, name: str, value: Any) -> None:
        """Set attribute ``name`` to ``value`` on every module that has it.

        ``self.modules()`` yields this module itself plus all of its
        descendants; modules lacking the attribute are left untouched.
        """
        for submodule in self.modules():
            if not hasattr(submodule, name):
                continue
            setattr(submodule, name, value)

    def set_is_inference(self, is_inference: bool) -> None:
        """Switch inference mode on this module and its sub-modules."""
        self._is_inference = is_inference
        self.recursive_setattr("_is_inference", is_inference)

    def set_kernel(self, kernel: Kernel) -> None:
        """Switch the kernel type on this module and its sub-modules."""
        self._kernel = kernel
        self.recursive_setattr("_kernel", kernel)

    @property
    def is_inference(self) -> bool:
        """Whether the module is in inference mode."""
        return self._is_inference

    @property
    def is_eval(self) -> bool:
        """True when neither in inference mode nor in training mode."""
        return not (self._is_inference or self.training)

    @property
    def is_train(self) -> bool:
        """True when in training mode and not in inference mode."""
        if self._is_inference:
            return False
        return self.training


class Transpose(nn.Module):
    """Module wrapper that swaps two dimensions of its input.

    Args:
        dim0 (int): first dimension to swap.
        dim1 (int): second dimension to swap.
    """

    def __init__(self, dim0: int, dim1: int) -> None:
        super().__init__()
        self.dim0, self.dim1 = dim0, dim1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with dimensions ``dim0`` and ``dim1`` swapped."""
        return torch.transpose(x, self.dim0, self.dim1)


def div_no_nan(
    input: torch.Tensor,
    other: torch.Tensor,
    *,
    rounding_mode: Optional[str] = None,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Divide input by other, replacing NaN and +/-inf results with 0.

    Args:
        input (Tensor): the dividend
        other (Tensor): the divisor
        rounding_mode (str, optional):
            Type of rounding applied to the result:
            - None: default behavior. Performs no rounding and, if both input and other
                are integer types, promotes the inputs to the default scalar type.
                Equivalent to true division in Python (the / operator) and NumPy's
                np.true_divide.
            - "trunc": rounds the results of the division towards zero. Equivalent to
                C-style integer division.
            - "floor": rounds the results of the division down. Equivalent to floor
                division in Python (the // operator) and NumPy's np.floor_divide.
        out (Tensor, optional): the output tensor. When given, it receives the
            sanitized quotient (NaN/inf already replaced by 0) and is returned.

    Returns:
        out (Tensor): the quotient with every NaN, +inf and -inf set to 0.
    """
    quotient = torch.div(input, other, rounding_mode=rounding_mode, out=out)
    if out is None:
        # Out-of-place so autograd sees the same graph as a plain div + nan_to_num.
        return torch.nan_to_num(quotient, nan=0.0, posinf=0.0, neginf=0.0)
    # `quotient` is `out` here; sanitize it in place so the caller's buffer
    # holds the cleaned values instead of the raw NaN/inf quotient.
    return quotient.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
