# tzrec/models/match_model.py
# Copyright (c) 2024-2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Union

import torch
from torch import nn

from tzrec.datasets.utils import Batch
from tzrec.features.feature import BaseFeature
from tzrec.metrics import recall_at_k
from tzrec.models.model import BaseModel
from tzrec.modules.embedding import EmbeddingGroup
from tzrec.modules.utils import BaseModule, div_no_nan
from tzrec.modules.variational_dropout import VariationalDropout
from tzrec.protos import model_pb2, simi_pb2, tower_pb2
from tzrec.protos.loss_pb2 import LossConfig
from tzrec.protos.metric_pb2 import MetricConfig
from tzrec.protos.model_pb2 import ModelConfig
from tzrec.utils.config_util import config_to_kwargs


@torch.fx.wrap
def _zero_int_label(pred: torch.Tensor) -> torch.Tensor:
    """All-zero int labels; with sampled negatives the positive item is column 0."""
    return torch.zeros((pred.size(0),), dtype=torch.int64, device=pred.device)


@torch.fx.wrap
def _arange_int_label(pred: torch.Tensor) -> torch.Tensor:
    """Arange labels; with in-batch negatives the positive is on the diagonal."""
    return torch.arange(pred.size(0), dtype=torch.int64, device=pred.device)


@torch.fx.wrap
def _update_tensor_2_dict(
    tensor_dict: Dict[str, torch.Tensor], new_tensor: torch.Tensor, key: str
) -> None:
    """Write a tensor into a dict (fx-wrapped so tracing keeps it as a call)."""
    tensor_dict[key] = new_tensor


class MatchTower(BaseModule):
    """Base match tower.

    Args:
        tower_config (Tower): user/item tower config.
        output_dim (int): user/item output embedding dimension.
        similarity (Similarity): if COSINE similarity is used, the output
            embedding will be normalized.
        feature_groups (list): list of feature group configs.
        features (list): list of features.
        model_config (ModelConfig): the model config.
    """

def __init__(
self,
tower_config: Union[
tower_pb2.Tower,
tower_pb2.DATTower,
tower_pb2.MINDUserTower,
tower_pb2.MINDItemTower,
],
output_dim: int,
similarity: simi_pb2.Similarity,
feature_groups: List[model_pb2.FeatureGroupConfig],
features: List[BaseFeature],
model_config: model_pb2.ModelConfig,
) -> None:
super().__init__()
self._tower_config = tower_config
self._group_name = tower_config.input
self._output_dim = output_dim
self._similarity = similarity
self._feature_groups = feature_groups
self._features = features
self._model_config = model_config
self.embedding_group = None
self.group_variational_dropouts = None
self.group_variational_dropout_loss = {}

    def init_input(self) -> None:
"""Build embedding group and group variational dropout."""
self.embedding_group = EmbeddingGroup(self._features, self._feature_groups)
if self._model_config.HasField("variational_dropout"):
self.group_variational_dropouts = nn.ModuleDict()
variational_dropout_config = self._model_config.variational_dropout
variational_dropout_config_dict = config_to_kwargs(
variational_dropout_config
)
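            # Create a VariationalDropout module for each non-sequence feature
            # group that has more than one feature; each module produces a
            # feature-dropout regularization loss that is collected in
            # build_input.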
for feature_group in self._feature_groups:
if feature_group.group_type != model_pb2.SEQUENCE:
feature_dim = self.embedding_group.group_feature_dims(
feature_group.group_name
)
if len(feature_dim) > 1:
variational_dropout = VariationalDropout(
feature_dim,
feature_group.group_name,
**variational_dropout_config_dict,
)
self.group_variational_dropouts[feature_group.group_name] = (
variational_dropout
)

    def build_input(self, batch: Batch) -> Dict[str, torch.Tensor]:
"""Build input feature."""
feature_dict = self.embedding_group(batch)
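        # Apply variational dropout to each applicable feature group, replacing
        # the group embedding and recording the per-group regularization loss
        # in self.group_variational_dropout_loss.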
if self.group_variational_dropouts is not None:
for (
group_name,
variational_dropout,
) in self.group_variational_dropouts.items():
feature, variational_dropout_loss = variational_dropout(
feature_dict[group_name]
)
_update_tensor_2_dict(feature_dict, feature, group_name)
_update_tensor_2_dict(
self.group_variational_dropout_loss,
variational_dropout_loss,
group_name + "_feature_p_loss",
)
return feature_dict


class MatchTowerWoEG(nn.Module):
    """Base match tower without an embedding group, used for embedding sharing.

    Args:
        tower_config (Tower): user/item tower config.
        output_dim (int): user/item output embedding dimension.
        similarity (Similarity): if COSINE similarity is used, the output
            embedding will be normalized.
        feature_group (FeatureGroupConfig): feature group config.
        features (list): list of features.
    """

def __init__(
self,
tower_config: Union[
tower_pb2.Tower,
tower_pb2.HSTUMatchTower,
],
output_dim: int,
similarity: simi_pb2.Similarity,
feature_group: model_pb2.FeatureGroupConfig,
features: List[BaseFeature],
) -> None:
super().__init__()
self._tower_config = tower_config
self._group_name = tower_config.input
self._output_dim = output_dim
self._similarity = similarity
self._feature_group = feature_group
self._features = features


class MatchModel(BaseModel):
    """Base model for match.

    Args:
        model_config (ModelConfig): an instance of ModelConfig.
        features (list): list of features.
        labels (list): list of label names.
        sample_weights (list): list of sample weight names.
    """

def __init__(
self,
model_config: ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
super().__init__(model_config, features, labels, sample_weights, **kwargs)
self._num_class = model_config.num_class
self._label_name = labels[0]
self._sample_weight = sample_weights[0] if sample_weights else sample_weights
self._in_batch_negative = False
self._loss_collection = {}
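        # With in_batch_negative enabled, the other items in the mini-batch are
        # reused as negatives; otherwise negatives are expected to be sampled
        # and stacked after the batch positives (see sim()).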
if self._model_config and hasattr(self._model_config, "in_batch_negative"):
self._in_batch_negative = self._model_config.in_batch_negative

    def sim(self, user_emb: torch.Tensor, item_emb: torch.Tensor) -> torch.Tensor:
"""Calculate user and item embedding similarity."""
if self._in_batch_negative:
return torch.mm(user_emb, item_emb.T)
else:
batch_size = user_emb.size(0)
pos_item_emb = item_emb[:batch_size]
neg_item_emb = item_emb[batch_size:]
pos_ui_sim = torch.sum(
torch.multiply(user_emb, pos_item_emb), dim=-1, keepdim=True
)
neg_ui_sim = torch.matmul(user_emb, neg_item_emb.transpose(0, 1))
return torch.cat([pos_ui_sim, neg_ui_sim], dim=-1)

    def _init_loss_impl(self, loss_cfg: LossConfig, suffix: str = "") -> None:
loss_type = loss_cfg.WhichOneof("loss")
loss_name = loss_type + suffix
assert loss_type == "softmax_cross_entropy", (
"match model only support softmax_cross_entropy loss now."
)
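        # Use per-sample reduction when sample weights are configured so the
        # weighted mean can be computed in _loss_impl.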
reduction = "none" if self._sample_weight else "mean"
self._loss_modules[loss_name] = nn.CrossEntropyLoss(reduction=reduction)

    def init_loss(self) -> None:
"""Initialize loss modules."""
assert len(self._base_model_config.losses) == 1, (
"match model only support single loss now."
)
for loss_cfg in self._base_model_config.losses:
self._init_loss_impl(loss_cfg)

    def _loss_impl(
self,
predictions: Dict[str, torch.Tensor],
batch: Batch,
label: torch.Tensor,
loss_cfg: LossConfig,
suffix: str = "",
) -> Dict[str, torch.Tensor]:
losses = {}
sample_weight = (
batch.sample_weights[self._sample_weight]
if self._sample_weight
else torch.Tensor([1.0])
)
loss_type = loss_cfg.WhichOneof("loss")
loss_name = loss_type + suffix
assert loss_type == "softmax_cross_entropy", (
"match model only support softmax_cross_entropy loss now."
)
pred = predictions["similarity" + suffix]
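        # Softmax cross entropy labels depend on the negative sampling scheme:
        # diagonal targets (arange) for in-batch negatives, column-0 targets
        # (zeros) when sampled negatives are appended after the positives.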
if self._in_batch_negative:
label = _arange_int_label(pred)
else:
label = _zero_int_label(pred)
losses[loss_name] = self._loss_modules[loss_name](pred, label)
if self._sample_weight:
losses[loss_name] = div_no_nan(
torch.mean(losses[loss_name] * sample_weight), torch.mean(sample_weight)
)
return losses

    def loss(
self, predictions: Dict[str, torch.Tensor], batch: Batch
) -> Dict[str, torch.Tensor]:
"""Compute loss of the model."""
losses = {}
for loss_cfg in self._base_model_config.losses:
losses.update(
self._loss_impl(
predictions, batch, batch.labels[self._label_name], loss_cfg
)
)
losses.update(self._loss_collection)
return losses

    def _init_metric_impl(self, metric_cfg: MetricConfig, suffix: str = "") -> None:
metric_type = metric_cfg.WhichOneof("metric")
metric_name = metric_type + suffix
oneof_metric_cfg = getattr(metric_cfg, metric_type)
metric_kwargs = config_to_kwargs(oneof_metric_cfg)
if metric_type == "recall_at_k":
metric_name = f"recall@{oneof_metric_cfg.top_k}" + suffix
self._metric_modules[metric_name] = recall_at_k.RecallAtK(**metric_kwargs)
else:
raise ValueError(f"{metric_type} is not supported for this model")

    def init_metric(self) -> None:
"""Initialize metric modules."""
for metric_cfg in self._base_model_config.metrics:
self._init_metric_impl(metric_cfg)
for loss_cfg in self._base_model_config.losses:
self._init_loss_metric_impl(loss_cfg)

    def _update_metric_impl(
self,
predictions: Dict[str, torch.Tensor],
batch: Batch,
label: torch.Tensor,
metric_cfg: MetricConfig,
suffix: str = "",
) -> None:
metric_type = metric_cfg.WhichOneof("metric")
metric_name = metric_type + suffix
oneof_metric_cfg = getattr(metric_cfg, metric_type)
if metric_type == "recall_at_k":
metric_name = f"recall@{oneof_metric_cfg.top_k}" + suffix
pred = predictions["similarity" + suffix]
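            # Build a boolean relevance matrix for recall@k: an identity matrix
            # for in-batch negatives (the positive is on the diagonal), else a
            # matrix whose first column marks the positive item.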
if self._in_batch_negative:
label = torch.eye(*pred.size(), dtype=torch.bool, device=pred.device)
else:
label = torch.zeros_like(pred, dtype=torch.bool)
label[:, 0] = True
self._metric_modules[metric_name].update(pred, label)
else:
raise ValueError(f"{metric_type} is not supported for this model")

    def update_metric(
        self,
        predictions: Dict[str, torch.Tensor],
        batch: Batch,
        losses: Optional[Dict[str, torch.Tensor]] = None,
    ) -> None:
        """Update metric state.

        Args:
            predictions (dict): a dict of predicted results.
            batch (Batch): input batch data.
            losses (dict, optional): a dict of loss.
        """
for metric_cfg in self._base_model_config.metrics:
self._update_metric_impl(
predictions, batch, batch.labels[self._label_name], metric_cfg
)
if losses is not None:
for loss_cfg in self._base_model_config.losses:
self._update_loss_metric_impl(
losses, batch, batch.labels[self._label_name], loss_cfg
)


class TowerWrapper(nn.Module):
    """Tower inference wrapper for jit.script."""

def __init__(self, module: nn.Module, tower_name: str = "user_tower") -> None:
super().__init__()
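        # Register the tower under its configured name (e.g. "user_tower") so
        # the scripted wrapper exposes that same attribute name.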
setattr(self, tower_name, module)
self._features = module._features
self._tower_name = tower_name

    def predict(self, batch: Batch) -> Dict[str, torch.Tensor]:
        """Forward the tower.

        Args:
            batch (Batch): input batch data.

        Returns:
            embedding (dict): tower output embedding.
        """
return {f"{self._tower_name}_emb": getattr(self, self._tower_name)(batch)}


class TowerWoEGWrapper(nn.Module):
    """Tower without embedding group inference wrapper for jit.script."""

def __init__(self, module: nn.Module, tower_name: str = "user_tower") -> None:
super().__init__()
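        # The wrapped tower shares its embeddings with other towers, so it owns
        # no embedding group; build one here from the tower's feature group so
        # the wrapper can run standalone at inference time.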
self.embedding_group = EmbeddingGroup(module._features, [module._feature_group])
setattr(self, tower_name, module)
self._features = module._features
self._tower_name = tower_name
self._group_name = module._group_name

    def predict(self, batch: Batch) -> Dict[str, torch.Tensor]:
        """Forward the tower.

        Args:
            batch (Batch): input batch data.

        Returns:
            embedding (dict): tower output embedding.
        """
grouped_features = self.embedding_group(batch)
return {
f"{self._tower_name}_emb": getattr(self, self._tower_name)(
grouped_features[self._group_name]
)
}
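

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library API): the shape conventions
# assumed by MatchModel.sim() and MatchModel._loss_impl() above.
#
#     user_emb = torch.randn(4, 16)          # [batch_size, dim]
#     item_emb = torch.randn(4 + 10, 16)     # batch positives, then negatives
#     # sampled negatives: similarity is [4, 1 + 10], positives in column 0,
#     # so the softmax cross entropy labels are torch.zeros(4).
#     # in_batch_negative: item_emb is [4, 16], similarity is [4, 4] and the
#     # labels are torch.arange(4) (positives on the diagonal).
# ---------------------------------------------------------------------------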