tzrec/models/match_model.py (270 lines of code) (raw):

# Copyright (c) 2024-2025, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Dict, List, Optional, Union import torch from torch import nn from tzrec.datasets.utils import Batch from tzrec.features.feature import BaseFeature from tzrec.metrics import recall_at_k from tzrec.models.model import BaseModel from tzrec.modules.embedding import EmbeddingGroup from tzrec.modules.utils import BaseModule, div_no_nan from tzrec.modules.variational_dropout import VariationalDropout from tzrec.protos import model_pb2, simi_pb2, tower_pb2 from tzrec.protos.loss_pb2 import LossConfig from tzrec.protos.metric_pb2 import MetricConfig from tzrec.protos.model_pb2 import ModelConfig from tzrec.utils.config_util import config_to_kwargs @torch.fx.wrap def _zero_int_label(pred: torch.Tensor) -> torch.Tensor: return torch.zeros((pred.size(0),), dtype=torch.int64, device=pred.device) @torch.fx.wrap def _arange_int_label(pred: torch.Tensor) -> torch.Tensor: return torch.arange(pred.size(0), dtype=torch.int64, device=pred.device) @torch.fx.wrap def _update_tensor_2_dict( tensor_dict: Dict[str, torch.Tensor], new_tensor: torch.Tensor, key: str ) -> None: tensor_dict[key] = new_tensor class MatchTower(BaseModule): """Base match tower. Args: tower_config (Tower): user/item tower config. output_dim (int): user/item output embedding dimension. similarity (Similarity): when use COSINE similarity, will norm the output embedding. feature_groups(list) (FeatureGroupConfig): feature group config. features (list): list of features. """ def __init__( self, tower_config: Union[ tower_pb2.Tower, tower_pb2.DATTower, tower_pb2.MINDUserTower, tower_pb2.MINDItemTower, ], output_dim: int, similarity: simi_pb2.Similarity, feature_groups: List[model_pb2.FeatureGroupConfig], features: List[BaseFeature], model_config: model_pb2.ModelConfig, ) -> None: super().__init__() self._tower_config = tower_config self._group_name = tower_config.input self._output_dim = output_dim self._similarity = similarity self._feature_groups = feature_groups self._features = features self._model_config = model_config self.embedding_group = None self.group_variational_dropouts = None self.group_variational_dropout_loss = {} def init_input(self) -> None: """Build embedding group and group variational dropout.""" self.embedding_group = EmbeddingGroup(self._features, self._feature_groups) if self._model_config.HasField("variational_dropout"): self.group_variational_dropouts = nn.ModuleDict() variational_dropout_config = self._model_config.variational_dropout variational_dropout_config_dict = config_to_kwargs( variational_dropout_config ) for feature_group in self._feature_groups: if feature_group.group_type != model_pb2.SEQUENCE: feature_dim = self.embedding_group.group_feature_dims( feature_group.group_name ) if len(feature_dim) > 1: variational_dropout = VariationalDropout( feature_dim, feature_group.group_name, **variational_dropout_config_dict, ) self.group_variational_dropouts[feature_group.group_name] = ( variational_dropout ) def build_input(self, batch: Batch) -> Dict[str, torch.Tensor]: """Build input feature.""" feature_dict = self.embedding_group(batch) if self.group_variational_dropouts is not None: for ( group_name, variational_dropout, ) in self.group_variational_dropouts.items(): feature, variational_dropout_loss = variational_dropout( feature_dict[group_name] ) _update_tensor_2_dict(feature_dict, feature, group_name) _update_tensor_2_dict( self.group_variational_dropout_loss, variational_dropout_loss, group_name + "_feature_p_loss", ) return feature_dict class MatchTowerWoEG(nn.Module): """Base match tower without embedding group for share embedding. Args: tower_config (Tower): user/item tower config. output_dim (int): user/item output embedding dimension. similarity (Similarity): when use COSINE similarity, will norm the output embedding. feature_group (FeatureGroupConfig): feature group config. features (list): list of features. """ def __init__( self, tower_config: Union[ tower_pb2.Tower, tower_pb2.HSTUMatchTower, ], output_dim: int, similarity: simi_pb2.Similarity, feature_group: model_pb2.FeatureGroupConfig, features: List[BaseFeature], ) -> None: super().__init__() self._tower_config = tower_config self._group_name = tower_config.input self._output_dim = output_dim self._similarity = similarity self._feature_group = feature_group self._features = features class MatchModel(BaseModel): """Base model for match. Args: model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. sample_weights (list): sample weight names """ def __init__( self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: Optional[List[str]] = None, **kwargs: Any, ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) self._num_class = model_config.num_class self._label_name = labels[0] self._sample_weight = sample_weights[0] if sample_weights else sample_weights self._in_batch_negative = False self._loss_collection = {} if self._model_config and hasattr(self._model_config, "in_batch_negative"): self._in_batch_negative = self._model_config.in_batch_negative def sim(self, user_emb: torch.Tensor, item_emb: torch.Tensor) -> torch.Tensor: """Calculate user and item embedding similarity.""" if self._in_batch_negative: return torch.mm(user_emb, item_emb.T) else: batch_size = user_emb.size(0) pos_item_emb = item_emb[:batch_size] neg_item_emb = item_emb[batch_size:] pos_ui_sim = torch.sum( torch.multiply(user_emb, pos_item_emb), dim=-1, keepdim=True ) neg_ui_sim = torch.matmul(user_emb, neg_item_emb.transpose(0, 1)) return torch.cat([pos_ui_sim, neg_ui_sim], dim=-1) def _init_loss_impl(self, loss_cfg: LossConfig, suffix: str = "") -> None: loss_type = loss_cfg.WhichOneof("loss") loss_name = loss_type + suffix assert loss_type == "softmax_cross_entropy", ( "match model only support softmax_cross_entropy loss now." ) reduction = "none" if self._sample_weight else "mean" self._loss_modules[loss_name] = nn.CrossEntropyLoss(reduction=reduction) def init_loss(self) -> None: """Initialize loss modules.""" assert len(self._base_model_config.losses) == 1, ( "match model only support single loss now." ) for loss_cfg in self._base_model_config.losses: self._init_loss_impl(loss_cfg) def _loss_impl( self, predictions: Dict[str, torch.Tensor], batch: Batch, label: torch.Tensor, loss_cfg: LossConfig, suffix: str = "", ) -> Dict[str, torch.Tensor]: losses = {} sample_weight = ( batch.sample_weights[self._sample_weight] if self._sample_weight else torch.Tensor([1.0]) ) loss_type = loss_cfg.WhichOneof("loss") loss_name = loss_type + suffix assert loss_type == "softmax_cross_entropy", ( "match model only support softmax_cross_entropy loss now." ) pred = predictions["similarity" + suffix] if self._in_batch_negative: label = _arange_int_label(pred) else: label = _zero_int_label(pred) losses[loss_name] = self._loss_modules[loss_name](pred, label) if self._sample_weight: losses[loss_name] = div_no_nan( torch.mean(losses[loss_name] * sample_weight), torch.mean(sample_weight) ) return losses def loss( self, predictions: Dict[str, torch.Tensor], batch: Batch ) -> Dict[str, torch.Tensor]: """Compute loss of the model.""" losses = {} for loss_cfg in self._base_model_config.losses: losses.update( self._loss_impl( predictions, batch, batch.labels[self._label_name], loss_cfg ) ) losses.update(self._loss_collection) return losses def _init_metric_impl(self, metric_cfg: MetricConfig, suffix: str = "") -> None: metric_type = metric_cfg.WhichOneof("metric") metric_name = metric_type + suffix oneof_metric_cfg = getattr(metric_cfg, metric_type) metric_kwargs = config_to_kwargs(oneof_metric_cfg) if metric_type == "recall_at_k": metric_name = f"recall@{oneof_metric_cfg.top_k}" + suffix self._metric_modules[metric_name] = recall_at_k.RecallAtK(**metric_kwargs) else: raise ValueError(f"{metric_type} is not supported for this model") def init_metric(self) -> None: """Initialize metric modules.""" for metric_cfg in self._base_model_config.metrics: self._init_metric_impl(metric_cfg) for loss_cfg in self._base_model_config.losses: self._init_loss_metric_impl(loss_cfg) def _update_metric_impl( self, predictions: Dict[str, torch.Tensor], batch: Batch, label: torch.Tensor, metric_cfg: MetricConfig, suffix: str = "", ) -> None: metric_type = metric_cfg.WhichOneof("metric") metric_name = metric_type + suffix oneof_metric_cfg = getattr(metric_cfg, metric_type) if metric_type == "recall_at_k": metric_name = f"recall@{oneof_metric_cfg.top_k}" + suffix pred = predictions["similarity" + suffix] if self._in_batch_negative: label = torch.eye(*pred.size(), dtype=torch.bool, device=pred.device) else: label = torch.zeros_like(pred, dtype=torch.bool) label[:, 0] = True self._metric_modules[metric_name].update(pred, label) else: raise ValueError(f"{metric_type} is not supported for this model") def update_metric( self, predictions: Dict[str, torch.Tensor], batch: Batch, losses: Optional[Dict[str, torch.Tensor]] = None, ) -> None: """Update metric state. Args: predictions (dict): a dict of predicted result. batch (Batch): input batch data. losses (dict, optional): a dict of loss. """ for metric_cfg in self._base_model_config.metrics: self._update_metric_impl( predictions, batch, batch.labels[self._label_name], metric_cfg ) if losses is not None: for loss_cfg in self._base_model_config.losses: self._update_loss_metric_impl( losses, batch, batch.labels[self._label_name], loss_cfg ) class TowerWrapper(nn.Module): """Tower inference wrapper for jit.script.""" def __init__(self, module: nn.Module, tower_name: str = "user_tower") -> None: super().__init__() setattr(self, tower_name, module) self._features = module._features self._tower_name = tower_name def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Forward the tower. Args: batch (Batch): input batch data. Return: embedding (dict): tower output embedding. """ return {f"{self._tower_name}_emb": getattr(self, self._tower_name)(batch)} class TowerWoEGWrapper(nn.Module): """Tower without embedding group inference wrapper for jit.script.""" def __init__(self, module: nn.Module, tower_name: str = "user_tower") -> None: super().__init__() self.embedding_group = EmbeddingGroup(module._features, [module._feature_group]) setattr(self, tower_name, module) self._features = module._features self._tower_name = tower_name self._group_name = module._group_name def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Forward the tower. Args: batch (Batch): input batch data. Return: embedding (dict): tower output embedding. """ grouped_features = self.embedding_group(batch) return { f"{self._tower_name}_emb": getattr(self, self._tower_name)( grouped_features[self._group_name] ) }