tzrec/modules/mlp.py

# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Union

import torch
from torch import nn

from tzrec.modules.activation import create_activation
from tzrec.modules.utils import Transpose


class Perceptron(nn.Module):
    """Applies a linear transformation and activation.

    Args:
        in_features (int): number of elements in each input sample.
        out_features (int): number of elements in each output sample.
        activation (str, optional): the activation function to apply to the
            output of the linear transformation. Default: torch.nn.ReLU.
        use_bn (bool): whether to use batch normalization.
        bias (bool): if set to False, the layer will not learn an additive bias.
        dropout_ratio (float): dropout ratio of the layer.
        dim (int): dimensionality of the input tensor, 2 or 3.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        activation: Optional[str] = "nn.ReLU",
        use_bn: bool = False,
        bias: bool = True,
        dropout_ratio: float = 0.0,
        dim: int = 2,
    ) -> None:
        super().__init__()
        self.activation = activation
        self.use_bn = use_bn
        self.dropout_ratio = dropout_ratio
        # When batch norm follows, its affine shift makes the linear bias redundant.
        self.perceptron = nn.Sequential(
            nn.Linear(in_features, out_features, bias=False if use_bn else bias)
        )
        if use_bn:
            assert dim in [2, 3]
            # BatchNorm1d normalizes over dim 1, so 3-D inputs need the feature
            # dimension transposed into place and back.
            if dim == 3:
                self.perceptron.append(Transpose(1, 2))
            self.perceptron.append(nn.BatchNorm1d(out_features))
            if dim == 3:
                self.perceptron.append(Transpose(1, 2))
        if activation and len(activation) > 0:
            act_module = create_activation(
                activation, hidden_size=out_features, dim=dim
            )
            if act_module:
                self.perceptron.append(act_module)
            else:
                raise ValueError(f"Unknown activation method: {activation}")
        if dropout_ratio > 0.0:
            self.perceptron.append(nn.Dropout(dropout_ratio))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Forward the module."""
        return self.perceptron(input)


class MLP(nn.Module):
    """Applies a stack of Perceptron modules sequentially.

    Args:
        in_features (int): number of input features.
        hidden_units (list): output size of each Perceptron layer.
        bias (bool): if set to False, the layers will not learn an additive bias.
            Default: True.
        activation (str, optional): the activation function to apply to the
            output of each linear transformation. Default: torch.nn.ReLU.
        use_bn (bool): whether to use batch normalization.
        dropout_ratio (float|list, optional): dropout ratio of each layer;
            a single float is broadcast to all layers.
        dim (int): dimensionality of the input tensor, 2 or 3.
        return_hidden_layer_feature (bool): whether to return the output of
            every hidden layer as a dict instead of only the last output.
    """

    def __init__(
        self,
        in_features: int,
        hidden_units: List[int],
        bias: bool = True,
        activation: Optional[str] = "nn.ReLU",
        use_bn: bool = False,
        dropout_ratio: Optional[Union[List[float], float]] = None,
        dim: int = 2,
        return_hidden_layer_feature: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_units = hidden_units
        self.activation = activation
        self.use_bn = use_bn
        self.return_hidden_layer_feature = return_hidden_layer_feature
        # Normalize dropout_ratio to one ratio per hidden layer.
        if dropout_ratio is None:
            dropout_ratio = [0.0] * len(hidden_units)
        elif isinstance(dropout_ratio, list):
            if len(dropout_ratio) == 0:
                dropout_ratio = [0.0] * len(hidden_units)
            elif len(dropout_ratio) == 1:
                dropout_ratio = dropout_ratio * len(hidden_units)
            else:
                assert len(dropout_ratio) == len(hidden_units), (
                    "length of dropout_ratio and hidden_units must be the same, "
                    f"but got {len(dropout_ratio)} vs {len(hidden_units)}"
                )
        else:
            dropout_ratio = [dropout_ratio] * len(hidden_units)
        self.dropout_ratio = dropout_ratio

        self.mlp = nn.ModuleList()
        for i in range(len(hidden_units)):
            # The first layer consumes the module input; each later layer
            # consumes the previous layer's output.
            if i > 0:
                in_features = hidden_units[i - 1]
            self.mlp.append(
                Perceptron(
                    in_features=in_features,
                    out_features=hidden_units[i],
                    activation=activation,
                    use_bn=use_bn,
                    bias=bias,
                    dropout_ratio=dropout_ratio[i],
                    dim=dim,
                )
            )

    def output_dim(self) -> int:
        """Output dimension of the module."""
        return self.hidden_units[-1]

    def forward(
        self, input: torch.Tensor
    ) -> Union[Dict[str, torch.Tensor], torch.Tensor]:
        """Forward the module."""
        net = input
        hidden_feature_dict = {}
        for i, tmp_mlp in enumerate(self.mlp):
            net = tmp_mlp(net)
            if self.return_hidden_layer_feature:
                hidden_feature_dict["hidden_layer" + str(i)] = net
                if i + 1 == len(self.mlp):
                    hidden_feature_dict["hidden_layer_end"] = net
        if self.return_hidden_layer_feature:
            return hidden_feature_dict
        else:
            return net