# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional

import torch
from torch import nn
from torch.nn import functional as F

from tzrec.utils.load_class import load_by_path
from tzrec.utils.logging_util import logger


class Dice(nn.Module):
    """Data Adaptive Activation Function in DIN.

    Computes ``alpha * (1 - p) * x + p * x`` where ``p = sigmoid(bn(x))``,
    ``bn`` is a non-affine batch normalization over the hidden dimension and
    ``alpha`` is a learnable per-channel parameter initialized to zero.

    Args:
        hidden_size (int): hidden dim of input (size of the last dimension).
        dim (int): number of input dims, 2 or 3. For 3, the input is
            transposed on dims (1, 2) before batch normalization so that the
            last input dimension is the normalized channel dimension.
        device (torch.device, optional): device of the parameters and buffers.
        dtype (torch.dtype, optional): dtype of the parameters and buffers.
    """

    def __init__(
        self,
        hidden_size: int,
        dim: int = 2,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        assert dim in [2, 3], f"Dice only supports dim 2 or 3, but got {dim}."
        # Forward device/dtype here as well, otherwise the running statistics
        # are created with the defaults while alpha follows the request.
        # pyre-ignore [6]
        self.bn = nn.BatchNorm1d(hidden_size, affine=False, **factory_kwargs)
        # pyre-ignore [6]
        self.alpha = nn.Parameter(torch.empty((hidden_size,), **factory_kwargs))
        self.dim = dim
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize parameters."""
        nn.init.zeros_(self.alpha)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward the module.

        Args:
            x (torch.Tensor): input whose last dimension is hidden_size, with
                self.dim dimensions in total.

        Returns:
            torch.Tensor: activated tensor with the same shape as x.
        """
        if self.dim == 2:
            x_p = torch.sigmoid(self.bn(x))
        else:
            # BatchNorm1d normalizes dim 1, so move the hidden dimension there
            # for the normalization and move the gate back afterwards.
            x_p = torch.sigmoid(self.bn(x.transpose(1, 2))).transpose(1, 2)
        # alpha has shape (hidden_size,) and broadcasts over the last dimension.
        return self.alpha * (1 - x_p) * x + x_p * x


def create_activation(act_str: str = "nn.ReLU", **kwargs: Any) -> Optional[nn.Module]:
    """Create activation module.

    Args:
        act_str (str): activation description. Either ``"Dice"``, or a class
            path resolved by ``load_by_path`` optionally followed by keyword
            arguments in parentheses, in the form ``"path(key=value, ...)"``.
            Arguments are split on ``,`` so values that contain commas are
            not supported.
        **kwargs: extra arguments. ``hidden_size`` and ``dim`` are required
            when act_str is ``"Dice"``; they are not used otherwise.

    Returns:
        The activation module, or None when act_str is empty or the class
        path can not be resolved (an error is logged in the latter case).

    Raises:
        AssertionError: if ``"Dice"`` is requested without hidden_size and dim.
        Exception: any error raised while parsing the keyword arguments is
            logged and re-raised unchanged.
    """
    act_str = act_str.strip()

    act_module = None
    if act_str == "Dice":
        assert "hidden_size" in kwargs and "dim" in kwargs, (
            "Dice activation method need hidden_size and dim params."
        )
        hidden_size = kwargs["hidden_size"]
        dim = kwargs["dim"]
        act_module = Dice(hidden_size, dim)
    elif len(act_str) > 0:
        act_strs = act_str.strip(")").split("(", 1)

        act_class = load_by_path(act_strs[0])
        if act_class is None:
            logger.error(f"Unknown activation [{act_str}]")
        else:
            act_params = {}
            if len(act_strs) > 1:
                try:
                    for kv in act_strs[1].split(","):
                        # Tolerate "path()" and trailing commas.
                        if not kv.strip():
                            continue
                        # Split once so that values may contain "=". A segment
                        # without "=" still fails with IndexError below.
                        parts = kv.split("=", 1)
                        # Strip the key so "a=1, b=2" does not produce " b".
                        # NOTE(security): eval executes arbitrary expressions;
                        # act_str must come from a trusted configuration.
                        act_params[parts[0].strip()] = eval(parts[1])
                except Exception:
                    logger.error(f"Can not parse activation [{act_str}]")
                    raise
            act_module = act_class(**act_params)

    return act_module
