# tzrec/modules/masknet_module.py
# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import torch
from torch import nn

from tzrec.modules.mlp import MLP
from tzrec.protos.module_pb2 import MaskNetModule as MaskNetModuleConfig
from tzrec.utils.config_util import config_to_kwargs


class MaskBlock(nn.Module):
    """MaskBlock module.

    Generates an instance-guided mask from the feature embedding and applies it
    element-wise to the LayerNorm-ed block input, followed by a feed-forward
    network.

    Args:
        input_dim (int): Input dimension, either the feature embedding dim
            (parallel mode) or the hidden state dim (serial mode).
        mask_input_dim (int): Mask input dimension; always the feature
            embedding dim in both parallel and serial modes.
        reduction_ratio (float): Reduction ratio, i.e.
            aggregation_dim / mask_input_dim.
        aggregation_dim (int): Aggregation layer dim,
            mask_input_dim * reduction_ratio.
        hidden_dim (int): Hidden layer dimension of the feed-forward network.
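
    Example:
        A minimal shape sketch; the sizes below are illustrative::

            block = MaskBlock(
                input_dim=16,
                mask_input_dim=16,
                reduction_ratio=2.0,
                aggregation_dim=0,
                hidden_dim=32,
            )
            emb = torch.randn(4, 16)
            out = block(emb, emb)  # out.shape == (4, 16)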
"""

    def __init__(
self,
input_dim: int,
mask_input_dim: int,
reduction_ratio: float,
aggregation_dim: int,
hidden_dim: int,
) -> None:
super(MaskBlock, self).__init__()
self.ln_emb = nn.LayerNorm(input_dim)
if not aggregation_dim and not reduction_ratio:
raise ValueError(
"Either aggregation_dim or reduction_ratio must be provided."
)
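        # If both are provided, reduction_ratio takes precedence below.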
if aggregation_dim:
self.aggregation_dim = aggregation_dim
if reduction_ratio:
self.aggregation_dim = int(mask_input_dim * reduction_ratio)
        assert self.aggregation_dim > 0, (
            "aggregation_dim must be > 0, check your aggregation_dim or "
            "reduction_ratio settings."
        )
self.mask_generator = nn.Sequential(
nn.Linear(mask_input_dim, self.aggregation_dim),
nn.ReLU(),
nn.Linear(self.aggregation_dim, input_dim),
)
assert hidden_dim > 0, "hidden_dim must be > 0."
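        # Feed-forward block applied to the masked (element-wise weighted) input.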
self.ffn = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, input_dim),
)

    def forward(self, input: torch.Tensor, mask_input: torch.Tensor) -> torch.Tensor:
        """Forward pass of MaskBlock.

        Args:
            input (torch.Tensor): Block input, shape (..., input_dim).
            mask_input (torch.Tensor): Feature embedding the mask is generated
                from, shape (..., mask_input_dim).

        Returns:
            torch.Tensor: Masked and transformed output, shape (..., input_dim).
        """
ln_emb = self.ln_emb(input)
weights = self.mask_generator(mask_input)
weighted_emb = ln_emb * weights
output = self.ffn(weighted_emb)
return output


class MaskNetModule(nn.Module):
    """MaskNet module.

    Args:
        module_config (MaskNetModuleConfig): an instance of MaskNetModuleConfig.
        feature_dim (int): input feature embedding dim.
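
    Example:
        A minimal usage sketch; ``module_config`` is assumed to be an already
        parsed ``MaskNetModuleConfig`` proto, and the sizes are illustrative::

            module = MaskNetModule(module_config, feature_dim=16)
            emb = torch.randn(4, 16)  # (batch_size, feature_dim)
            out = module(emb)  # (batch_size, top MLP output dim)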
"""

    def __init__(
self,
module_config: MaskNetModuleConfig,
feature_dim: int,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.use_parallel = module_config.use_parallel
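        # Each MaskBlock takes the raw feature embedding as its mask input; in
        # parallel mode every block also consumes the raw embedding directly,
        # while in serial mode blocks are chained on the previous hidden state.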
self.mask_blocks = nn.ModuleList(
[
MaskBlock(
feature_dim,
feature_dim,
module_config.mask_block.reduction_ratio,
module_config.mask_block.aggregation_dim,
module_config.mask_block.hidden_dim,
)
for _ in range(module_config.n_mask_blocks)
]
)
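        # Parallel mode concatenates all block outputs, so the top MLP input
        # width is feature_dim * n_mask_blocks; serial mode keeps feature_dim.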
if self.use_parallel:
self.top_mlp = MLP(
in_features=feature_dim * module_config.n_mask_blocks,
**config_to_kwargs(module_config.top_mlp),
)
else:
self.top_mlp = MLP(
in_features=feature_dim,
**config_to_kwargs(module_config.top_mlp),
)

    def forward(self, feature_emb: torch.Tensor) -> torch.Tensor:
        """Forward pass of MaskNetModule.

        Args:
            feature_emb (torch.Tensor): Feature embedding, shape
                (..., feature_dim).

        Returns:
            torch.Tensor: Output of the top MLP.
        """
if self.use_parallel: # parallel mask blocks
            hidden = torch.cat(
                [
                    mask_block(feature_emb, feature_emb)
                    for mask_block in self.mask_blocks
                ],
                dim=-1,
            )
else: # serial mask blocks
hidden = self.mask_blocks[0](feature_emb, feature_emb)
for i in range(1, len(self.mask_blocks)):
hidden = self.mask_blocks[i](hidden, feature_emb)
return self.top_mlp(hidden)