# tzrec/models/hstu.py
# Copyright (c) 2024-2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F
from torch import Tensor

from tzrec.datasets.utils import NEG_DATA_GROUP, Batch
from tzrec.features.feature import BaseFeature
from tzrec.models.match_model import MatchModel, MatchTowerWoEG
from tzrec.modules.embedding import EmbeddingGroup
from tzrec.modules.sequence import HSTUEncoder
from tzrec.protos import model_pb2, simi_pb2, tower_pb2
from tzrec.utils import config_util
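

# ``torch.fx.wrap`` makes this helper a leaf call during FX symbolic tracing,
# so the data-dependent ``None``/emptiness checks below do not break graph
# capture.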
@torch.fx.wrap
def _update_dict_tensor(
    tensor_dict: Dict[str, torch.Tensor],
    new_tensor_dict: Optional[Dict[str, Optional[torch.Tensor]]],
) -> None:
    """Merge the non-None entries of new_tensor_dict into tensor_dict in place."""
    if new_tensor_dict:
        for k, v in new_tensor_dict.items():
            if v is not None:
                tensor_dict[k] = v


class HSTUMatchUserTower(MatchTowerWoEG):
    """HSTU Match model user tower.

    Args:
        tower_config (HSTUMatchTower): user tower config.
        output_dim (int): user output embedding dimension.
        similarity (Similarity): when COSINE similarity is used, the output
            embedding will be L2-normalized.
        feature_group (FeatureGroupConfig): feature group config.
        feature_group_dims (list): embedding dimensions of the features in
            the feature group.
        features (list): list of features.
        model_config (ModelConfig): an instance of ModelConfig.
    """

def __init__(
self,
tower_config: tower_pb2.HSTUMatchTower,
output_dim: int,
similarity: simi_pb2.Similarity,
feature_group: model_pb2.FeatureGroupConfig,
feature_group_dims: List[int],
features: List[BaseFeature],
model_config: model_pb2.ModelConfig,
) -> None:
super().__init__(tower_config, output_dim, similarity, feature_group, features)
self.tower_config = tower_config
encoder_config = tower_config.hstu_encoder
seq_config_dict = config_util.config_to_kwargs(encoder_config)
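        # The HSTU encoder consumes the concatenated per-feature sequence
        # embeddings, so its input dimension is the sum of the group's dims.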
sequence_dim = sum(feature_group_dims)
seq_config_dict["sequence_dim"] = sequence_dim
self.seq_encoder = HSTUEncoder(**seq_config_dict)

    def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Forward the tower.

        Args:
            grouped_features (dict): dictionary of grouped feature tensors.

        Returns:
            torch.Tensor: the user embedding produced by the HSTU encoder.
        """
output = self.seq_encoder(grouped_features)
return output


class HSTUMatchItemTower(MatchTowerWoEG):
    """HSTU Match model item tower.

    Args:
        tower_config (HSTUMatchTower): item tower config.
        output_dim (int): item output embedding dimension.
        similarity (Similarity): when COSINE similarity is used, the output
            embedding will be L2-normalized.
        feature_group (FeatureGroupConfig): feature group config.
        features (list): list of features.
    """

def __init__(
self,
        tower_config: tower_pb2.HSTUMatchTower,
output_dim: int,
similarity: simi_pb2.Similarity,
feature_group: model_pb2.FeatureGroupConfig,
features: List[BaseFeature],
) -> None:
super().__init__(tower_config, output_dim, similarity, feature_group, features)
self.tower_config = tower_config

    def forward(self, grouped_features: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Forward the tower.

        Args:
            grouped_features (dict): dictionary of grouped feature tensors.

        Returns:
            torch.Tensor: the L2-normalized item embeddings.
        """
output = grouped_features[f"{self._group_name}.sequence"]
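        # L2-normalize each item embedding along the feature dimension so the
        # dot products in ``sim`` act as cosine similarities.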
        output = F.normalize(output, p=2.0, dim=-1, eps=1e-6)
return output


class HSTUMatch(MatchModel):
    """HSTU Match model.

    Args:
        model_config (ModelConfig): an instance of ModelConfig.
        features (list): list of features.
        labels (list): list of label names.
        sample_weights (list): list of sample weight names.
    """

def __init__(
self,
model_config: model_pb2.ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
super().__init__(model_config, features, labels, sample_weights, **kwargs)
        assert (
            len(model_config.feature_groups) == 1
        ), "HSTUMatch expects exactly one feature group."
self.embedding_group = EmbeddingGroup(
features, list(model_config.feature_groups)
)
name_to_feature_group = {x.group_name: x for x in model_config.feature_groups}
feature_group = name_to_feature_group[self._model_config.hstu_tower.input]
used_features = self.get_features_in_feature_groups([feature_group])
self.user_tower = HSTUMatchUserTower(
self._model_config.hstu_tower,
self._model_config.output_dim,
self._model_config.similarity,
feature_group,
self.embedding_group.group_dims(
self._model_config.hstu_tower.input + ".sequence"
),
used_features,
model_config,
)
self.item_tower = HSTUMatchItemTower(
self._model_config.hstu_tower,
self._model_config.output_dim,
self._model_config.similarity,
feature_group,
used_features,
)
self.seq_tower_input = self._model_config.hstu_tower.input

    def predict(self, batch: Batch) -> Dict[str, Tensor]:
        """Forward the model.

        Args:
            batch (Batch): input batch data.

        Returns:
            predictions (dict): a dict of predicted results.
        """
        batch_sparse_features = batch.sparse_features[NEG_DATA_GROUP]
        # batch_size comes from the label tensor; in the negative-sampling data
        # group, each item row holds 1 positive plus neg_sample_size sampled
        # negatives, so the length of the first item row gives neg_sample_size.
        batch_size = batch.labels[self._label_name].shape[0]
        neg_sample_size = batch_sparse_features.lengths()[batch_size] - 1
grouped_features = self.embedding_group(batch)
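        # The grouped "<input>.sequence" tensor stacks user rows first
        # ([0, batch_size)) followed by item rows ([batch_size, ...)), so the
        # two towers are fed disjoint slices of the same grouped features.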
item_group_features = {
self.seq_tower_input + ".sequence": grouped_features[
self.seq_tower_input + ".sequence"
][batch_size:, : neg_sample_size + 1],
}
item_tower_emb = self.item_tower(item_group_features)
user_group_features = {
self.seq_tower_input + ".sequence": grouped_features[
self.seq_tower_input + ".sequence"
][:batch_size],
self.seq_tower_input + ".sequence_length": grouped_features[
self.seq_tower_input + ".sequence_length"
][:batch_size],
}
user_tower_emb = self.user_tower(user_group_features)
ui_sim = (
self.sim(user_tower_emb, item_tower_emb, neg_for_each_sample=True)
/ self._model_config.temperature
)
return {"similarity": ui_sim}

    def sim(
self,
user_emb: torch.Tensor,
item_emb: torch.Tensor,
neg_for_each_sample: bool = False,
) -> torch.Tensor:
"""Override the sim method in MatchModel to calculate similarity."""
if self._in_batch_negative:
return torch.mm(user_emb, item_emb.T)
else:
batch_size = user_emb.size(0)
pos_item_emb = item_emb[:, 0]
neg_item_emb = item_emb[:, 1:].reshape(-1, item_emb.shape[-1])
pos_ui_sim = torch.sum(
torch.multiply(user_emb, pos_item_emb), dim=-1, keepdim=True
)
neg_ui_sim = None
if not neg_for_each_sample:
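                # Shared negatives: score every user against the whole
                # flattened negative pool -> [batch_size, total_num_neg].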
neg_ui_sim = torch.matmul(user_emb, neg_item_emb.transpose(0, 1))
else:
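                # Per-user negatives: reshape the pool to
                # [batch_size, num_neg_per_user, dim] and score each user only
                # against their own sampled negatives.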
num_neg_per_user = neg_item_emb.size(0) // batch_size
neg_size = batch_size * num_neg_per_user
neg_item_emb = neg_item_emb[:neg_size]
neg_item_emb = neg_item_emb.view(batch_size, num_neg_per_user, -1)
neg_ui_sim = torch.sum(user_emb.unsqueeze(1) * neg_item_emb, dim=-1)
return torch.cat([pos_ui_sim, neg_ui_sim], dim=-1)