tzrec/models/hstu.py:

# Copyright (c) 2024-2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F
from torch import Tensor

from tzrec.datasets.utils import NEG_DATA_GROUP, Batch
from tzrec.features.feature import BaseFeature
from tzrec.models.match_model import MatchModel, MatchTowerWoEG
from tzrec.modules.embedding import EmbeddingGroup
from tzrec.modules.sequence import HSTUEncoder
from tzrec.protos import model_pb2, simi_pb2, tower_pb2
from tzrec.utils import config_util


@torch.fx.wrap
def _update_dict_tensor(
    tensor_dict: Dict[str, torch.Tensor],
    new_tensor_dict: Optional[Dict[str, Optional[torch.Tensor]]],
) -> None:
    """Merge the non-None entries of new_tensor_dict into tensor_dict."""
    if new_tensor_dict:
        for k, v in new_tensor_dict.items():
            if v is not None:
                tensor_dict[k] = v


class HSTUMatchUserTower(MatchTowerWoEG):
    """HSTU Match model user tower.

    Args:
        tower_config (HSTUMatchTower): user tower config.
        output_dim (int): user output embedding dimension.
        similarity (Similarity): when using COSINE similarity, the output
            embedding will be L2-normalized.
        feature_group (FeatureGroupConfig): feature group config.
        feature_group_dims (list): embedding dimensions of the sequence
            features in the group.
        features (list): list of features.
        model_config (ModelConfig): an instance of ModelConfig.
    """

    def __init__(
        self,
        tower_config: tower_pb2.HSTUMatchTower,
        output_dim: int,
        similarity: simi_pb2.Similarity,
        feature_group: model_pb2.FeatureGroupConfig,
        feature_group_dims: List[int],
        features: List[BaseFeature],
        model_config: model_pb2.ModelConfig,
    ) -> None:
        super().__init__(tower_config, output_dim, similarity, feature_group, features)
        self.tower_config = tower_config
        encoder_config = tower_config.hstu_encoder
        seq_config_dict = config_util.config_to_kwargs(encoder_config)
        # The encoder input dimension is the concatenation of all sequence
        # feature embeddings in the group.
        sequence_dim = sum(feature_group_dims)
        seq_config_dict["sequence_dim"] = sequence_dim
        self.seq_encoder = HSTUEncoder(**seq_config_dict)

    def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Forward the tower.

        Args:
            grouped_features: a dict of grouped feature tensors.

        Returns:
            torch.Tensor: the user embedding output by the HSTU encoder.
        """
        output = self.seq_encoder(grouped_features)
        return output


class HSTUMatchItemTower(MatchTowerWoEG):
    """HSTU Match model item tower.

    Args:
        tower_config (HSTUMatchTower): item tower config.
        output_dim (int): item output embedding dimension.
        similarity (Similarity): when using COSINE similarity, the output
            embedding will be L2-normalized.
        feature_group (FeatureGroupConfig): feature group config.
        features (list): list of features.
    """

    def __init__(
        self,
        tower_config: tower_pb2.HSTUMatchTower,
        output_dim: int,
        similarity: simi_pb2.Similarity,
        feature_group: model_pb2.FeatureGroupConfig,
        features: List[BaseFeature],
    ) -> None:
        super().__init__(tower_config, output_dim, similarity, feature_group, features)
        self.tower_config = tower_config

    def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Forward the tower.

        Args:
            grouped_features: a dict of grouped feature tensors.

        Returns:
            torch.Tensor: the L2-normalized item embeddings.
        """
        output = grouped_features[f"{self._group_name}.sequence"]
        # L2-normalize along the embedding dimension so that the dot products
        # in HSTUMatch.sim correspond to cosine similarity.
        output = F.normalize(output, p=2.0, dim=-1, eps=1e-6)
        return output
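
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: the group name
# "hstu" and all sizes below are assumptions. It shows the input layout the
# two towers above consume -- the user tower reads "<group>.sequence"
# ([B, N, D]) together with "<group>.sequence_length" ([B]), while the item
# tower reads only "<group>.sequence" and L2-normalizes it.
def _dummy_grouped_features_sketch() -> Dict[str, torch.Tensor]:
    """Build dummy grouped features with the layout the towers expect."""
    batch_size, seq_len, seq_dim = 4, 50, 64  # hypothetical sizes
    return {
        "hstu.sequence": torch.randn(batch_size, seq_len, seq_dim),
        "hstu.sequence_length": torch.randint(1, seq_len + 1, (batch_size,)),
    }
# ---------------------------------------------------------------------------
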
class HSTUMatch(MatchModel):
    """HSTU Match model.

    Args:
        model_config (ModelConfig): an instance of ModelConfig.
        features (list): list of features.
        labels (list): list of label names.
        sample_weights (list): list of sample weight names.
    """

    def __init__(
        self,
        model_config: model_pb2.ModelConfig,
        features: List[BaseFeature],
        labels: List[str],
        sample_weights: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(model_config, features, labels, sample_weights, **kwargs)
        assert len(model_config.feature_groups) == 1
        self.embedding_group = EmbeddingGroup(
            features, list(model_config.feature_groups)
        )
        name_to_feature_group = {x.group_name: x for x in model_config.feature_groups}
        feature_group = name_to_feature_group[self._model_config.hstu_tower.input]
        used_features = self.get_features_in_feature_groups([feature_group])
        self.user_tower = HSTUMatchUserTower(
            self._model_config.hstu_tower,
            self._model_config.output_dim,
            self._model_config.similarity,
            feature_group,
            self.embedding_group.group_dims(
                self._model_config.hstu_tower.input + ".sequence"
            ),
            used_features,
            model_config,
        )
        self.item_tower = HSTUMatchItemTower(
            self._model_config.hstu_tower,
            self._model_config.output_dim,
            self._model_config.similarity,
            feature_group,
            used_features,
        )
        self.seq_tower_input = self._model_config.hstu_tower.input

    def predict(self, batch: Batch) -> Dict[str, Tensor]:
        """Forward the model.

        Args:
            batch (Batch): input batch data.

        Returns:
            predictions (dict): a dict of predicted results.
        """
        batch_sparse_features = batch.sparse_features[NEG_DATA_GROUP]
        # Derive batch_size from the labels and the per-sample negative count
        # from the sparse feature lengths.
        batch_size = batch.labels[self._label_name].shape[0]
        neg_sample_size = batch_sparse_features.lengths()[batch_size] - 1
        grouped_features = self.embedding_group(batch)
        # Rows [batch_size:] hold the item sequences; each keeps one positive
        # item followed by neg_sample_size negatives.
        item_group_features = {
            self.seq_tower_input + ".sequence": grouped_features[
                self.seq_tower_input + ".sequence"
            ][batch_size:, : neg_sample_size + 1],
        }
        item_tower_emb = self.item_tower(item_group_features)
        # Rows [:batch_size] hold the user behavior sequences.
        user_group_features = {
            self.seq_tower_input + ".sequence": grouped_features[
                self.seq_tower_input + ".sequence"
            ][:batch_size],
            self.seq_tower_input + ".sequence_length": grouped_features[
                self.seq_tower_input + ".sequence_length"
            ][:batch_size],
        }
        user_tower_emb = self.user_tower(user_group_features)
        ui_sim = (
            self.sim(user_tower_emb, item_tower_emb, neg_for_each_sample=True)
            / self._model_config.temperature
        )
        return {"similarity": ui_sim}

    def sim(
        self,
        user_emb: torch.Tensor,
        item_emb: torch.Tensor,
        neg_for_each_sample: bool = False,
    ) -> torch.Tensor:
        """Override the sim method in MatchModel to calculate similarity."""
        if self._in_batch_negative:
            return torch.mm(user_emb, item_emb.T)
        else:
            batch_size = user_emb.size(0)
            # Column 0 of item_emb holds the positive item; the rest are
            # negatives.
            pos_item_emb = item_emb[:, 0]
            neg_item_emb = item_emb[:, 1:].reshape(-1, item_emb.shape[-1])
            pos_ui_sim = torch.sum(
                torch.multiply(user_emb, pos_item_emb), dim=-1, keepdim=True
            )
            if not neg_for_each_sample:
                # Share all negatives across the batch.
                neg_ui_sim = torch.matmul(user_emb, neg_item_emb.transpose(0, 1))
            else:
                # Give each user its own slice of negatives.
                num_neg_per_user = neg_item_emb.size(0) // batch_size
                neg_size = batch_size * num_neg_per_user
                neg_item_emb = neg_item_emb[:neg_size]
                neg_item_emb = neg_item_emb.view(batch_size, num_neg_per_user, -1)
                neg_ui_sim = torch.sum(user_emb.unsqueeze(1) * neg_item_emb, dim=-1)
            return torch.cat([pos_ui_sim, neg_ui_sim], dim=-1)
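

# ---------------------------------------------------------------------------
# Shape walk-through of `sim` with neg_for_each_sample=True (a minimal
# sketch; the sizes below are assumptions, not values used by the model).
# `user_emb` is [B, D]; `item_emb` packs one positive followed by `num_neg`
# negatives per user, i.e. [B, 1 + num_neg, D]; the result is
# [B, 1 + num_neg] with the positive logit in column 0.
if __name__ == "__main__":
    B, num_neg, D = 4, 10, 64
    user_emb = torch.randn(B, D)
    item_emb = F.normalize(torch.randn(B, 1 + num_neg, D), dim=-1)
    pos = torch.sum(user_emb * item_emb[:, 0], dim=-1, keepdim=True)  # [B, 1]
    neg = torch.sum(user_emb.unsqueeze(1) * item_emb[:, 1:], dim=-1)  # [B, num_neg]
    logits = torch.cat([pos, neg], dim=-1)
    assert logits.shape == (B, 1 + num_neg)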