tzrec/modules/activation.py (61 lines of code) (raw):

# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional

import torch
from torch import nn
from torch.nn import functional as F

from tzrec.utils.load_class import load_by_path
from tzrec.utils.logging_util import logger


class Dice(nn.Module):
    """Data Adaptive Activation Function in DIN.

    Computes ``p * x + alpha * (1 - p) * x`` with ``p = sigmoid(BN(x))``, where
    BN is a non-affine ``BatchNorm1d`` over the ``hidden_size`` features and
    ``alpha`` is a learnable per-feature coefficient initialized to zero.

    Args:
        hidden_size (int): hidden dim of input, i.e. size of the feature axis.
        dim (int): rank of the input, 2 for ``(batch, hidden_size)`` inputs or
            3 for ``(batch, seq_len, hidden_size)`` inputs.
        device (torch.device, optional): device of parameters and buffers.
        dtype (torch.dtype, optional): dtype of parameters and buffers.
    """

    def __init__(
        self,
        hidden_size: int,
        dim: int = 2,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        assert dim in [2, 3]
        # Forward the factory kwargs to the batch norm as well, so its running
        # statistics buffers are created on the same device/dtype as ``alpha``.
        # pyre-ignore [6]
        self.bn = nn.BatchNorm1d(hidden_size, affine=False, **factory_kwargs)
        # pyre-ignore [6]
        self.alpha = nn.Parameter(torch.empty((hidden_size,), **factory_kwargs))
        self.dim = dim
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize parameters (``alpha`` is set to zeros)."""
        nn.init.zeros_(self.alpha)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward the module.

        Args:
            x (Tensor): input of shape ``(B, hidden_size)`` when ``dim == 2``,
                or ``(B, L, hidden_size)`` when ``dim == 3``.

        Returns:
            Tensor: activated tensor with the same shape as ``x``.
        """
        if self.dim == 2:
            x_p = F.sigmoid(self.bn(x))
            out = self.alpha * (1 - x_p) * x + x_p * x
        else:
            # BatchNorm1d normalizes over axis 1, so move the feature axis there:
            # (B, L, H) -> (B, H, L).
            x = x.transpose(1, 2)
            x_p = F.sigmoid(self.bn(x))
            # alpha -> (H, 1) so it broadcasts over (B, H, L).
            out = self.alpha.unsqueeze(1) * (1 - x_p) * x + x_p * x
            out = out.transpose(1, 2)
        return out


def create_activation(act_str: str = "nn.ReLU", **kwargs: Any) -> Optional[nn.Module]:
    """Create activation module.

    Args:
        act_str (str): activation spec. Either ``"Dice"``, or a class path
            resolvable by ``load_by_path``, optionally followed by keyword
            arguments in parentheses, e.g. ``"nn.ReLU"``, ``"nn.ReLU()"`` or
            ``"nn.LeakyReLU(negative_slope=0.1, inplace=True)"``. Arguments
            are split on ``,`` (so a value must not itself contain a comma),
            and each value is evaluated as a Python expression.
        **kwargs: extra params, only used by ``Dice``, which requires
            ``hidden_size`` and ``dim``.

    Returns:
        the activation module, or None when ``act_str`` is empty or its class
        cannot be loaded.
    """
    act_str = act_str.strip()
    act_module = None
    if act_str == "Dice":
        assert "hidden_size" in kwargs and "dim" in kwargs, (
            "Dice activation method need hidden_size and dim params."
        )
        hidden_size = kwargs["hidden_size"]
        dim = kwargs["dim"]
        act_module = Dice(hidden_size, dim)
    elif len(act_str) > 0:
        # Drop only the single closing parenthesis of the call, so argument
        # values such as ``(1)`` keep their own parentheses.
        spec = act_str[:-1] if act_str.endswith(")") else act_str
        act_strs = spec.split("(", 1)
        act_class = load_by_path(act_strs[0].strip())
        if act_class is None:
            logger.error(f"Unknown activation [{act_str}]")
        else:
            act_params: Dict[str, Any] = {}
            if len(act_strs) > 1:
                try:
                    for kv in act_strs[1].split(","):
                        if not kv.strip():
                            # Tolerate empty argument lists like "nn.ReLU()"
                            # and trailing commas.
                            continue
                        # Split on the first "=" only, so values may contain "=".
                        key, sep, value = kv.partition("=")
                        if not sep:
                            raise ValueError(f"missing '=' in argument [{kv}]")
                        # NOTE(review): eval executes arbitrary code; act_str
                        # must only come from a trusted config.
                        act_params[key.strip()] = eval(value.strip())
                except Exception:
                    logger.error(f"Can not parse activation [{act_str}]")
                    raise
            act_module = act_class(**act_params)
    return act_module