# tzrec/models/dat.py
# Copyright (c) 2024-2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from tzrec.datasets.utils import Batch
from tzrec.features.feature import BaseFeature
from tzrec.models.match_model import MatchModel, MatchTower
from tzrec.modules.mlp import MLP
from tzrec.modules.utils import div_no_nan
from tzrec.protos import model_pb2, simi_pb2, tower_pb2
from tzrec.utils.config_util import config_to_kwargs


@torch.fx.wrap
def _update_dict_tensor(
    tensor_dict: Dict[str, torch.Tensor],
    new_tensor_dict: Optional[Dict[str, Optional[torch.Tensor]]],
) -> None:
    """Merge non-None entries of new_tensor_dict into tensor_dict in place."""
    if new_tensor_dict:
        for k, v in new_tensor_dict.items():
            if v is not None:
                tensor_dict[k] = v


class DATTower(MatchTower):
    """DAT user/item tower.

    Args:
        tower_config (DATTower): user/item tower config.
        output_dim (int): user/item output embedding dimension.
        similarity (Similarity): with COSINE similarity, the output
            embedding is L2-normalized.
        feature_groups (list): main and augment feature group configs.
        features (list): list of features.
        model_config (ModelConfig): an instance of ModelConfig.
    """

def __init__(
self,
tower_config: tower_pb2.DATTower,
output_dim: int,
similarity: simi_pb2.Similarity,
feature_groups: List[model_pb2.FeatureGroupConfig],
features: List[BaseFeature],
model_config: model_pb2.ModelConfig,
) -> None:
super().__init__(
tower_config, output_dim, similarity, feature_groups, features, model_config
)
self.init_input()
self._augment_group_name = tower_config.augment_input
tower_feature_in = self.embedding_group.group_total_dim(self._group_name)
tower_augment_feature_in = self.embedding_group.group_total_dim(
self._augment_group_name
)
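        # The MLP input width is the sum of the main and augment group dims,
        # since forward() concatenates the two embeddings before the MLP.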
self.mlp = MLP(
tower_feature_in + tower_augment_feature_in,
**config_to_kwargs(tower_config.mlp),
)
if self._output_dim > 0:
self.output = nn.Linear(self.mlp.output_dim(), output_dim)

    def forward(
        self, batch: Batch
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward the tower.

        Args:
            batch (Batch): input batch data.

        Returns:
            embedding (Tensor): tower output embedding; in training mode, a
                tuple of (output embedding, augmented input embedding).
        """
grouped_features = self.build_input(batch)
input_features = grouped_features[self._group_name]
augmented_feature = grouped_features[self._augment_group_name]
output = self.mlp(torch.concat([input_features, augmented_feature], dim=1))
if self._output_dim > 0:
output = self.output(output)
if self._similarity == simi_pb2.Similarity.COSINE:
output = F.normalize(output, p=2.0, dim=1)
if self.training:
return output, augmented_feature
else:
return output
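
    # Usage sketch (hypothetical config and shapes):
    #   tower = DATTower(tower_cfg, output_dim=64, ...)
    #   emb, aug = tower(batch)  # training: emb [B, 64], aug [B, D_augment]
    #   emb = tower(batch)       # eval mode: emb [B, 64]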


class DAT(MatchModel):
    """Dual Augmented Two-tower model.

    Args:
        model_config (ModelConfig): an instance of ModelConfig.
        features (list): list of features.
        labels (list): list of label names.
        sample_weights (list): optional list of sample weight names.
    """

def __init__(
self,
model_config: model_pb2.ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
super().__init__(model_config, features, labels, sample_weights, **kwargs)
name_to_feature_group = {x.group_name: x for x in model_config.feature_groups}
user_group = name_to_feature_group[self._model_config.user_tower.input]
user_augment_group = name_to_feature_group[
self._model_config.user_tower.augment_input
]
item_group = name_to_feature_group[self._model_config.item_tower.input]
item_augment_group = name_to_feature_group[
self._model_config.item_tower.augment_input
]
user_features = self.get_features_in_feature_groups([user_group])
user_augment_features = self.get_features_in_feature_groups(
[user_augment_group]
)
item_features = self.get_features_in_feature_groups([item_group])
item_augment_features = self.get_features_in_feature_groups(
[item_augment_group]
)
self.user_tower = DATTower(
self._model_config.user_tower,
self._model_config.output_dim,
self._model_config.similarity,
[user_group, user_augment_group],
user_features + user_augment_features,
model_config,
)
self.item_tower = DATTower(
self._model_config.item_tower,
self._model_config.output_dim,
self._model_config.similarity,
[item_group, item_augment_group],
item_features + item_augment_features,
model_config,
)
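        # AMM loss weights: how strongly each augmented vector is pulled
        # toward the opposite tower's (detached) embedding in loss().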
self.amm_u_weight = self._model_config.amm_u_weight
self.amm_i_weight = self._model_config.amm_i_weight

    def predict(self, batch: Batch) -> Dict[str, Tensor]:
        """Forward the model.

        Args:
            batch (Batch): input batch data.

        Returns:
            predictions (dict): a dict of predicted results.
        """
if self.training:
user_tower_emb, user_augment = self.user_tower(batch)
item_tower_emb, item_augment = self.item_tower(batch)
else:
user_tower_emb = self.user_tower(batch)
item_tower_emb = self.item_tower(batch)
user_augment, item_augment = None, None
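        # Collect variational dropout regularization losses from both towers.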
_update_dict_tensor(
self._loss_collection, self.user_tower.group_variational_dropout_loss
)
_update_dict_tensor(
self._loss_collection, self.item_tower.group_variational_dropout_loss
)
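        # Temperature-scaled similarity logits; with sampled or in-batch
        # negatives, item_tower_emb may hold more rows than the user batch.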
ui_sim = (
self.sim(user_tower_emb, item_tower_emb) / self._model_config.temperature
)
if self.training:
return {
"similarity": ui_sim,
"user_augment": user_augment,
"item_augment": item_augment,
"user_tower_emb": user_tower_emb.detach(),
"item_tower_emb": item_tower_emb.detach(),
}
else:
return {"similarity": ui_sim}

    def loss(
        self, predictions: Dict[str, torch.Tensor], batch: Batch
    ) -> Dict[str, torch.Tensor]:
        """Compute model losses, including the Adaptive-Mimic Mechanism loss."""
losses = super().loss(predictions, batch)
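        # Per-sample weights if configured; otherwise a scalar weight of 1.0.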
sample_weight = (
batch.sample_weights[self._sample_weight]
if self._sample_weight
else torch.Tensor([1.0])
)
amm_loss = {}
batch_size = predictions["similarity"].size(0)
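        # Adaptive-Mimic Mechanism: each augmented vector learns to mimic the
        # other tower's detached embedding for the positive pairs, roughly
        #   amm_loss_u = w_u * || normalize(a_u) - sg(e_i) ||^2
        #   amm_loss_i = w_i * || normalize(a_i) - sg(e_u) ||^2
        # where sg(.) is the stop-gradient applied via .detach() in predict()
        # and item rows are truncated to the positives ([:batch_size]).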
if "user_augment" in predictions and "item_tower_emb" in predictions:
user_augment = predictions["user_augment"]
item_tower_emb = predictions["item_tower_emb"]
amm_loss_u = self.amm_u_weight * torch.sum(
torch.square(
F.normalize(user_augment, p=2.0, dim=1)
- item_tower_emb[:batch_size]
),
dim=1 if self._sample_weight else (0, 1),
)
amm_loss["amm_loss_u"] = amm_loss_u
if "item_augment" in predictions and "user_tower_emb" in predictions:
item_augment = predictions["item_augment"]
user_tower_emb = predictions["user_tower_emb"]
amm_loss_i = self.amm_i_weight * torch.sum(
torch.square(
F.normalize(item_augment[:batch_size], p=2.0, dim=1)
- user_tower_emb
),
dim=1 if self._sample_weight else (0, 1),
)
amm_loss["amm_loss_i"] = amm_loss_i
if self._sample_weight:
for loss_name in amm_loss.keys():
amm_loss[loss_name] = div_no_nan(
torch.mean(amm_loss[loss_name] * sample_weight),
torch.mean(sample_weight),
)
losses.update(amm_loss)
return losses
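

# Minimal usage sketch (hypothetical setup; real configs come from a pipeline
# config proto and a dataset-built Batch):
#   model = DAT(model_config, features, labels=["clk"])
#   predictions = model.predict(batch)
#   losses = model.loss(predictions, batch)  # adds amm_loss_u / amm_loss_i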