tzrec/models/multi_tower_din_trt.py (96 lines of code) (raw):

# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Any, Dict, List, Optional

import torch
from torch import nn

from tzrec.datasets.utils import Batch
from tzrec.features.feature import BaseFeature
from tzrec.models.rank_model import RankModel
from tzrec.modules.embedding import EmbeddingGroup
from tzrec.modules.mlp import MLP
from tzrec.modules.sequence import DINEncoder
from tzrec.protos.model_pb2 import ModelConfig
from tzrec.utils.config_util import config_to_kwargs


@torch.fx.wrap
def _get_dict(
    grouped_features_keys: List[str], args: List[torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """Pair each grouped-feature key with the tensor at the same position.

    Registered with ``torch.fx.wrap`` so that symbolic tracing records a call
    to this function instead of tracing through its body.

    Args:
        grouped_features_keys (list): names of the grouped features.
        args (list): tensors, in the same order as ``grouped_features_keys``.

    Return:
        a dict mapping every key to its positional tensor.

    Raises:
        ValueError: if the two lists differ in length.
    """
    if len(grouped_features_keys) != len(args):
        raise ValueError(
            "The number of grouped_features_keys must match the number of arguments."
        )
    return {name: tensor for name, tensor in zip(grouped_features_keys, args)}


class MultiTowerDINDense(RankModel):
    """Dense part of the multi-tower DIN model.

    Holds the MLP towers, the DIN encoders, the optional final MLP and the
    linear output layer. The embedding lookup itself is not done here; the
    module consumes a flat list of tensors ordered like
    ``embedding_group.grouped_features_keys()``.

    Args:
        embedding_group(EmbeddingGroup): embedding group, used at construction
            time to read the grouped-feature keys and the group dimensions.
        model_config (ModelConfig): an instance of ModelConfig.
        features (list): list of features.
        labels (list): list of label names.
        sample_weights (list): sample weight names.
    """

    def __init__(
        self,
        embedding_group: EmbeddingGroup,
        model_config: ModelConfig,
        features: List[BaseFeature],
        labels: List[str],
        sample_weights: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(model_config, features, labels, sample_weights, **kwargs)
        # Key order used by `predict` to rebuild the feature dict from a list.
        self.grouped_features_keys = embedding_group.grouped_features_keys()

        # Running width of the concatenated tower outputs.
        concat_dim = 0

        # One MLP per configured tower, stored under its input group name.
        self.towers = nn.ModuleDict()
        for tower_cfg in self._model_config.towers:
            name = tower_cfg.input
            mlp = MLP(
                embedding_group.group_total_dim(name),
                **config_to_kwargs(tower_cfg.mlp),
            )
            self.towers[name] = mlp
            concat_dim += mlp.output_dim()

        # One DIN encoder per configured din tower.
        self.din_towers = nn.ModuleList()
        for din_cfg in self._model_config.din_towers:
            name = din_cfg.input
            seq_dim = embedding_group.group_total_dim(f"{name}.sequence")
            qry_dim = embedding_group.group_total_dim(f"{name}.query")
            encoder = DINEncoder(
                seq_dim,
                qry_dim,
                name,
                attn_mlp=config_to_kwargs(din_cfg.attn_mlp),
            )
            self.din_towers.append(encoder)
            concat_dim += encoder.output_dim()

        # The final MLP is optional; without it the output layer reads the
        # concatenated tower outputs directly.
        if self._model_config.HasField("final"):
            self.final_mlp = MLP(
                in_features=concat_dim,
                **config_to_kwargs(self._model_config.final),
            )
            head_in = self.final_mlp.output_dim()
        else:
            head_in = concat_dim

        self.output_mlp = nn.Linear(head_in, self._num_class)

    # pyre-ignore [14]
    def predict(self, args: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Compute predictions from a flat list of grouped-feature tensors.

        Args:
            args (list): tensors ordered like ``self.grouped_features_keys``.

        Return:
            predictions (dict): the result of ``self._output_to_prediction``.
        """
        feature_map = _get_dict(self.grouped_features_keys, args)

        # MLP towers read their own group; DIN encoders get the whole dict.
        mlp_parts = [mlp(feature_map[name]) for name, mlp in self.towers.items()]
        din_parts = [encoder(feature_map) for encoder in self.din_towers]

        hidden = torch.cat(mlp_parts + din_parts, dim=-1)
        if self._model_config.HasField("final"):
            hidden = self.final_mlp(hidden)

        return self._output_to_prediction(self.output_mlp(hidden))

    # pyre-ignore [14]
    def forward(self, args: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Forward the module by delegating to `predict`.

        Args:
            args (list): tensors ordered like ``self.grouped_features_keys``.

        Return:
            predictions (dict): same dict as returned by `predict`.
        """
        # use predict to avoid trace error in self._output_to_prediction(y)
        return self.predict(args)


class MultiTowerDINTRT(RankModel):
    """Multi-tower DIN model split into an embedding part and a dense part.

    The embedding lookup lives in ``self.embedding_group`` and all remaining
    layers live in ``self.dense`` (a `MultiTowerDINDense`).
    NOTE(review): the split presumably exists so the dense part can be
    traced/converted separately (e.g. for TensorRT) - confirm with the
    export code.

    Args:
        model_config (ModelConfig): an instance of ModelConfig.
        features (list): list of features.
        labels (list): list of label names.
        sample_weights (list): sample weight names.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        features: List[BaseFeature],
        labels: List[str],
        sample_weights: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(model_config, features, labels, sample_weights, **kwargs)
        feature_groups = list(model_config.feature_groups)
        self.embedding_group = EmbeddingGroup(features, feature_groups)
        self.dense = MultiTowerDINDense(
            self.embedding_group, model_config, features, labels
        )

    def predict(self, batch: Batch) -> Dict[str, torch.Tensor]:
        """Forward the model.

        Args:
            batch (Batch): input batch data.

        Return:
            predictions (dict): a dict of predicted result.
        """
        # The embedding output is handed to the dense part unchanged.
        return self.dense.predict(self.embedding_group.predict(batch))