tzrec/models/multi_tower_din_trt.py (96 lines of code) (raw):
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, List, Optional
import torch
from torch import nn
from tzrec.datasets.utils import Batch
from tzrec.features.feature import BaseFeature
from tzrec.models.rank_model import RankModel
from tzrec.modules.embedding import EmbeddingGroup
from tzrec.modules.mlp import MLP
from tzrec.modules.sequence import DINEncoder
from tzrec.protos.model_pb2 import ModelConfig
from tzrec.utils.config_util import config_to_kwargs
@torch.fx.wrap
def _get_dict(
    grouped_features_keys: List[str], args: List[torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """Pair grouped-feature keys with tensors positionally.

    Wrapped with ``torch.fx.wrap`` so FX symbolic tracing records this call
    as a single leaf op instead of tracing through the dict construction.

    Args:
        grouped_features_keys (list): feature group key names.
        args (list): tensors, one per key, in the same order as the keys.

    Return:
        grouped_features (dict): mapping ``key -> tensor``.

    Raises:
        ValueError: if the number of keys and tensors differ.
    """
    if len(args) == len(grouped_features_keys):
        return dict(zip(grouped_features_keys, args))
    raise ValueError(
        "The number of grouped_features_keys must match the number of arguments."
    )
class MultiTowerDINDense(RankModel):
    """Dense (post-embedding) part of the multi-tower DIN model.

    Consumes already-embedded grouped features as a flat list of tensors,
    ordered like ``embedding_group.grouped_features_keys()``, and runs the
    MLP towers, DIN towers, optional final MLP and the output layer.

    Args:
        embedding_group (EmbeddingGroup): embedding group; used here only to
            read the grouped feature keys and per-group output dimensions.
        model_config (ModelConfig): an instance of ModelConfig.
        features (list): list of features.
        labels (list): list of label names.
        sample_weights (list): sample weight names.

    Raises:
        ValueError: if two MLP towers use the same input feature group.
    """

    def __init__(
        self,
        embedding_group: EmbeddingGroup,
        model_config: ModelConfig,
        features: List[BaseFeature],
        labels: List[str],
        sample_weights: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(model_config, features, labels, sample_weights, **kwargs)
        # The order of these keys must match the order of tensors given to
        # ``predict``; they are zipped together positionally.
        self.grouped_features_keys = embedding_group.grouped_features_keys()

        # Plain MLP towers, keyed by their input feature group name.
        self.towers = nn.ModuleDict()
        total_tower_dim = 0
        for tower in self._model_config.towers:
            group_name = tower.input
            if group_name in self.towers:
                # A repeated key would silently overwrite the earlier tower in
                # the ModuleDict while ``total_tower_dim`` still counted it,
                # so the final layers would only fail later with a shape
                # mismatch in ``predict``.
                raise ValueError(
                    f"Duplicate tower input feature group: {group_name}."
                )
            tower_feature_in = embedding_group.group_total_dim(group_name)
            tower_mlp = MLP(tower_feature_in, **config_to_kwargs(tower.mlp))
            self.towers[group_name] = tower_mlp
            total_tower_dim += tower_mlp.output_dim()

        # DIN towers read the "<group>.sequence" and "<group>.query" groups.
        self.din_towers = nn.ModuleList()
        for tower in self._model_config.din_towers:
            group_name = tower.input
            sequence_dim = embedding_group.group_total_dim(f"{group_name}.sequence")
            query_dim = embedding_group.group_total_dim(f"{group_name}.query")
            tower_din = DINEncoder(
                sequence_dim,
                query_dim,
                group_name,
                attn_mlp=config_to_kwargs(tower.attn_mlp),
            )
            self.din_towers.append(tower_din)
            total_tower_dim += tower_din.output_dim()

        # Optional MLP over the concatenation of all tower outputs.
        final_dim = total_tower_dim
        if self._model_config.HasField("final"):
            self.final_mlp = MLP(
                in_features=total_tower_dim,
                **config_to_kwargs(self._model_config.final),
            )
            final_dim = self.final_mlp.output_dim()
        self.output_mlp = nn.Linear(final_dim, self._num_class)

    # pyre-ignore [14]
    def predict(self, args: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run the dense towers on pre-embedded grouped features.

        Args:
            args (list): one tensor per grouped feature key, in the order of
                ``self.grouped_features_keys``.

        Return:
            predictions (dict): result of ``self._output_to_prediction``.
        """
        grouped_features = _get_dict(self.grouped_features_keys, args)
        tower_outputs = []
        for k, tower_mlp in self.towers.items():
            tower_outputs.append(tower_mlp(grouped_features[k]))
        for tower_din in self.din_towers:
            # DINEncoder selects its own sequence/query groups from the dict.
            tower_outputs.append(tower_din(grouped_features))
        tower_output = torch.cat(tower_outputs, dim=-1)
        if self._model_config.HasField("final"):
            tower_output = self.final_mlp(tower_output)
        y = self.output_mlp(tower_output)
        return self._output_to_prediction(y)

    # pyre-ignore [14]
    def forward(self, args: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Forward the module; identical to ``predict``."""
        # use predict to avoid trace error in self._output_to_prediction(y)
        return self.predict(args)
class MultiTowerDINTRT(RankModel):
    """Multi-tower DIN model split into an embedding stage and a dense stage.

    ``self.embedding_group`` converts the input batch into grouped feature
    tensors; the rest of the network lives in a ``MultiTowerDINDense``
    sub-module stored as ``self.dense``.

    Args:
        model_config (ModelConfig): an instance of ModelConfig.
        features (list): list of features.
        labels (list): list of label names.
        sample_weights (list): sample weight names.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        features: List[BaseFeature],
        labels: List[str],
        sample_weights: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(model_config, features, labels, sample_weights, **kwargs)
        feature_groups = list(model_config.feature_groups)
        self.embedding_group = EmbeddingGroup(features, feature_groups)
        # NOTE(review): sample_weights and kwargs are not forwarded to the
        # dense sub-module, so it is built with their defaults.
        self.dense = MultiTowerDINDense(
            self.embedding_group, model_config, features, labels
        )

    def predict(self, batch: Batch) -> Dict[str, torch.Tensor]:
        """Embed the batch, then run the dense stage on the embeddings.

        Args:
            batch (Batch): input batch data.

        Return:
            predictions (dict): a dict of predicted result.
        """
        embedded = self.embedding_group.predict(batch)
        return self.dense.predict(embedded)