tzrec/modules/utils.py (59 lines of code) (raw):

# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
from typing import Any, Optional

import torch
from torch import nn

from tzrec.ops import Kernel


class BaseModule(nn.Module, abc.ABC):
    """TorchEasyRec Base Module.

    Stores two pieces of mode state -- the inference flag and the kernel
    type -- and provides setters that push the new value down to every
    sub-module that has the same private attribute.

    Args:
        is_inference (bool): is inference or not.
        kernel (Optional[Kernel]): module kernel type. ``None`` means the
            default reported by :meth:`kernel` (``Kernel.TRITON``).
    """

    def __init__(
        self,
        is_inference: bool = False,
        kernel: Optional[Kernel] = None,
    ) -> None:
        super().__init__()
        self._is_inference = is_inference
        self._kernel = kernel

    def kernel(self) -> Kernel:
        """Get kernel type.

        Returns:
            Kernel: the configured kernel, or ``Kernel.TRITON`` when no
            kernel was set.
        """
        # Copy the Optional attribute into a local before the None check so
        # the refinement applies to a plain local name.
        kernel = self._kernel
        if kernel is not None:
            return kernel
        else:
            return Kernel.TRITON

    # pyre-ignore [2]
    def recursive_setattr(self, name: str, value: Any) -> None:
        """Recursive set sub module attrs.

        Visits ``self`` and all of its descendants and assigns ``value`` to
        attribute ``name`` on every module that already has that attribute.
        Modules that lack the attribute are left untouched.

        Args:
            name (str): attribute name to set.
            value (Any): value to assign.
        """
        # ``modules()`` also yields ``self``.
        for module in self.modules():
            if hasattr(module, name):
                setattr(module, name, value)

    def set_is_inference(self, is_inference: bool) -> None:
        """Set module in inference or not.

        Args:
            is_inference (bool): new inference flag, applied to this module
                and every sub-module that has an ``_is_inference`` attribute.
        """
        self._is_inference = is_inference
        self.recursive_setattr("_is_inference", is_inference)

    def set_kernel(self, kernel: Kernel) -> None:
        """Set module kernel type.

        Args:
            kernel (Kernel): new kernel type, applied to this module and every
                sub-module that has a ``_kernel`` attribute.
        """
        self._kernel = kernel
        self.recursive_setattr("_kernel", kernel)

    @property
    def is_inference(self) -> bool:
        """Get module is inference or not."""
        return self._is_inference

    @property
    def is_eval(self) -> bool:
        """Get module is eval or not (not inference and not training)."""
        return (not self._is_inference) and (not self.training)

    @property
    def is_train(self) -> bool:
        """Get module is train or not (not inference and training)."""
        return (not self._is_inference) and self.training


class Transpose(nn.Module):
    """Transpose Module.

    Wraps ``Tensor.transpose`` so it can be used inside ``nn.Sequential``.

    Args:
        dim0 (int): the first dimension to be transposed.
        dim1 (int): the second dimension to be transposed.
    """

    def __init__(self, dim0: int, dim1: int) -> None:
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward the module.

        Args:
            x (Tensor): input tensor.

        Returns:
            Tensor: ``x`` with dimensions ``dim0`` and ``dim1`` swapped.
        """
        return x.transpose(self.dim0, self.dim1)

    def extra_repr(self) -> str:
        """Show the transposed dimensions in the module's repr."""
        return f"dim0={self.dim0}, dim1={self.dim1}"


def div_no_nan(
    input: torch.Tensor,
    other: torch.Tensor,
    *,
    rounding_mode: Optional[str] = None,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Divide input by other, mapping NaN/inf results to 0.

    Every NaN, +inf or -inf element of the quotient is replaced with 0. For
    floating-point tensors this covers ``x / 0`` and ``0 / 0``. Note that it
    also zeroes NaN/inf values that were already present in ``input`` or
    ``other`` (e.g. ``inf / 2``).

    NOTE: for integer tensors with a rounding mode, PyTorch does not turn a
    division by zero into NaN/inf (it may raise instead), so that case is not
    guarded here.

    Args:
        input (Tensor): the dividend
        other (Tensor): the divisor
        rounding_mode (str, optional): Type of rounding applied to the result:
            - None: default behavior. Performs no rounding and, if both input
              and other are integer types, promotes the inputs to the default
              scalar type. Equivalent to true division in Python (the /
              operator) and NumPy's np.true_divide.
            - "trunc": rounds the results of the division towards zero.
              Equivalent to C-style integer division.
            - "floor": rounds the results of the division down. Equivalent to
              floor division in Python (the // operator) and NumPy's
              np.floor_divide.
        out (Tensor, optional): the output tensor. When given, it receives the
            sanitized quotient (no NaN/inf) and is also returned.

    Returns:
        out (Tensor): the quotient with NaN/inf replaced by 0.
    """
    if out is None:
        return torch.nan_to_num(
            torch.div(input, other, rounding_mode=rounding_mode),
            nan=0.0,
            posinf=0.0,
            neginf=0.0,
        )
    # torch.nan_to_num would allocate a fresh tensor and leave the raw
    # NaN/inf values in the caller's buffer; sanitize ``out`` in place.
    torch.div(input, other, rounding_mode=rounding_mode, out=out)
    return out.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)