# tzrec/models/mind.py
# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from tzrec.datasets.utils import Batch
from tzrec.features.feature import BaseFeature
from tzrec.models.match_model import MatchModel, MatchTower
from tzrec.modules.capsule import CapsuleLayer
from tzrec.modules.mlp import MLP
from tzrec.protos import model_pb2, simi_pb2, tower_pb2
from tzrec.utils.config_util import config_to_kwargs


class MINDUserTower(MatchTower):
"""MIND user tower.
Args:
tower_config (Tower): mind user tower config.
output_dim (int): user output embedding dimension.
similarity (Similarity): when use COSINE similarity,
will norm the output embedding.
user_feature_group (FeatureGroupConfig): user feature group config.
hist_feature_group (FeatureGroupConfig): history sequence feature group config.
user_features (list): list of user features.
hist_features (list): list of history sequence features.
model_config (ModelConfig): model config.
"""
def __init__(
self,
tower_config: tower_pb2.MINDUserTower,
output_dim: int,
similarity: simi_pb2.Similarity,
user_feature_group: model_pb2.FeatureGroupConfig,
hist_feature_group: model_pb2.FeatureGroupConfig,
user_features: List[BaseFeature],
hist_features: List[BaseFeature],
model_config: model_pb2.ModelConfig,
) -> None:
super().__init__(
tower_config,
output_dim,
similarity,
[user_feature_group, hist_feature_group],
user_features + hist_features,
model_config,
)
self._hist_group_name = tower_config.history_input
self.init_input()
user_feature_in = self.embedding_group.group_total_dim(self._group_name)
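        # With more than one hidden unit the user MLP keeps a separate Linear
        # head for the final projection; with a single unit a plain Linear on
        # the raw user features is enough.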
if len(tower_config.user_mlp.hidden_units) > 1:
self.user_mlp = MLP(
in_features=user_feature_in,
hidden_units=tower_config.user_mlp.hidden_units[0:-1],
activation=tower_config.user_mlp.activation,
use_bn=tower_config.user_mlp.use_bn,
dropout_ratio=tower_config.user_mlp.dropout_ratio[0]
if tower_config.user_mlp.dropout_ratio
else None,
)
self.user_mlp_out = nn.Linear(
self.user_mlp.output_dim(), tower_config.user_mlp.hidden_units[-1]
)
else:
            self.user_mlp = nn.Linear(
                user_feature_in, tower_config.user_mlp.hidden_units[-1]
            )
self.user_mlp_out = None
hist_feature_dim = self.embedding_group.group_total_dim(
self._hist_group_name + ".sequence"
)
if (
tower_config.hist_seq_mlp
and len(tower_config.hist_seq_mlp.hidden_units) > 1
):
self._hist_seq_mlp = MLP(
in_features=hist_feature_dim,
dim=3,
hidden_units=tower_config.hist_seq_mlp.hidden_units[0:-1],
activation=tower_config.hist_seq_mlp.activation,
use_bn=tower_config.hist_seq_mlp.use_bn,
bias=False,
dropout_ratio=tower_config.hist_seq_mlp.dropout_ratio[0]
if tower_config.hist_seq_mlp.dropout_ratio
else None,
)
self._hist_seq_mlp_out = nn.Linear(
self._hist_seq_mlp.output_dim(),
tower_config.hist_seq_mlp.hidden_units[-1],
bias=False,
)
capsule_input_dim = tower_config.hist_seq_mlp.hidden_units[-1]
elif (
tower_config.hist_seq_mlp
and len(tower_config.hist_seq_mlp.hidden_units) > 0
):
self._hist_seq_mlp = nn.Linear(
hist_feature_dim,
tower_config.hist_seq_mlp.hidden_units[-1],
bias=False,
)
self._hist_seq_mlp_out = None
capsule_input_dim = tower_config.hist_seq_mlp.hidden_units[-1]
else:
self._hist_seq_mlp = None
capsule_input_dim = hist_feature_dim
self._capsule_layer = CapsuleLayer(
capsule_config=tower_config.capsule_config,
input_dim=capsule_input_dim,
)
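        # Dynamic routing groups the behavior sequence into a fixed maximum
        # number of interest capsules; each valid capsule represents one
        # candidate user interest.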
self._concat_mlp = MLP(
in_features=tower_config.user_mlp.hidden_units[-1]
+ tower_config.capsule_config.high_dim,
dim=3,
**config_to_kwargs(tower_config.concat_mlp),
)
if self._output_dim > 0:
self.output = nn.Linear(
self._concat_mlp.output_dim(), output_dim, bias=False
)

    def forward(
        self, batch: Batch
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward the tower.

        Args:
            batch (Batch): input batch data.

        Returns:
            user_interests (Tensor): user interests; in training mode a tuple
                of (user_interests, high_capsules_mask) is returned.
        """
user_feature_dict = self.build_input(batch)
grp_hist_seq = user_feature_dict[self._hist_group_name + ".sequence"]
grp_hist_len = user_feature_dict[self._hist_group_name + ".sequence_length"]
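        # grp_hist_seq: [B, T, D] padded behavior-sequence embeddings;
        # grp_hist_len: [B] true (unpadded) sequence lengths.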
if self.user_mlp_out:
user_feature = self.user_mlp_out(
self.user_mlp(user_feature_dict[self._group_name])
)
else:
user_feature = self.user_mlp(user_feature_dict[self._group_name])
if self._hist_seq_mlp:
if self._hist_seq_mlp_out:
hist_seq_feas = self._hist_seq_mlp_out(self._hist_seq_mlp(grp_hist_seq))
else:
hist_seq_feas = self._hist_seq_mlp(grp_hist_seq)
else:
hist_seq_feas = grp_hist_seq
high_capsules, high_capsules_mask = self._capsule_layer(
hist_seq_feas, grp_hist_len
)
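        # high_capsules: [B, K, high_dim] interest capsules;
        # high_capsules_mask: [B, K], marking capsules backed by real history
        # (short sequences yield fewer valid interests).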
# concatenate user feature and high_capsules
user_feature = torch.unsqueeze(user_feature, dim=1)
user_feature_tile = torch.tile(user_feature, [1, high_capsules.shape[1], 1])
user_interests = torch.cat([user_feature_tile, high_capsules], dim=-1)
user_interests = user_interests * high_capsules_mask.unsqueeze(-1).float()
user_interests = self._concat_mlp(user_interests)
user_interests = user_interests * high_capsules_mask.unsqueeze(-1).float()
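        # Re-apply the mask: the concat MLP's bias/activation can turn the
        # zeroed padding rows nonzero again.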
if self._output_dim > 0:
user_interests = self.output(user_interests)
if self._similarity == simi_pb2.Similarity.COSINE:
user_interests = F.normalize(user_interests, p=2.0, dim=-1)
if self.is_inference:
return user_interests
else:
return user_interests, high_capsules_mask


class MINDItemTower(MatchTower):
    """MIND item tower.

    Args:
        tower_config (Tower): mind item tower config.
        output_dim (int): item output embedding dimension.
        similarity (Similarity): when COSINE similarity is used, the output
            embedding will be L2-normalized.
        item_feature_group (FeatureGroupConfig): item feature group config.
        item_features (list): list of item features.
        model_config (ModelConfig): model config.
    """

    def __init__(
self,
tower_config: tower_pb2.MINDItemTower,
output_dim: int,
similarity: simi_pb2.Similarity,
item_feature_group: model_pb2.FeatureGroupConfig,
item_features: List[BaseFeature],
model_config: model_pb2.ModelConfig,
) -> None:
super().__init__(
tower_config,
output_dim,
similarity,
[item_feature_group],
item_features,
model_config,
)
self.init_input()
tower_feature_in = self.embedding_group.group_total_dim(self._group_name)
self.mlp = MLP(tower_feature_in, **config_to_kwargs(tower_config.mlp))
if self._output_dim > 0:
self.output = nn.Linear(self.mlp.output_dim(), output_dim)

    def forward(self, batch: Batch) -> torch.Tensor:
        """Forward the tower.

        Args:
            batch (Batch): input batch data.

        Returns:
            item_emb (Tensor): item embedding.
        """
grouped_features = self.build_input(batch)
item_emb = self.mlp(grouped_features[self._group_name])
if self._output_dim > 0:
item_emb = self.output(item_emb)
if self._similarity == simi_pb2.Similarity.COSINE:
item_emb = F.normalize(item_emb, p=2.0, dim=1)
return item_emb


class MIND(MatchModel):
    """MIND (Multi-Interest Network with Dynamic routing) model.

    Args:
        model_config (ModelConfig): an instance of ModelConfig.
        features (list): list of features.
        labels (list): list of label names.
        sample_weights (list): list of sample weight names.
    """

    def __init__(
self,
model_config: model_pb2.ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
super().__init__(model_config, features, labels, sample_weights, **kwargs)
name_to_feature_group = {x.group_name: x for x in model_config.feature_groups}
user_group = name_to_feature_group[self._model_config.user_tower.input]
item_group = name_to_feature_group[self._model_config.item_tower.input]
hist_group = name_to_feature_group[self._model_config.user_tower.history_input]
user_features = self.get_features_in_feature_groups([user_group])
item_features = self.get_features_in_feature_groups([item_group])
hist_features = self.get_features_in_feature_groups([hist_group])
self.user_tower = MINDUserTower(
self._model_config.user_tower,
self._model_config.output_dim,
self._model_config.similarity,
user_group,
hist_group,
user_features,
hist_features,
model_config,
)
self.item_tower = MINDItemTower(
self._model_config.item_tower,
self._model_config.output_dim,
self._model_config.similarity,
item_group,
item_features,
model_config,
)

    def label_aware_attention(
self,
user_interests: torch.Tensor,
item_emb: torch.Tensor,
interest_mask: torch.Tensor,
) -> torch.Tensor:
"""Compute label-aware attention for user interests.
Args:
user_interests (Tensor): user interests.
item_emb (Tensor): item embedding.
interest_mask (Tensor): interest mask.
Returns:
user_emb (Tensor): user embedding.
"""
batch_size = user_interests.size(0)
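        # item_emb may carry extra rows appended for sampled negatives; the
        # first batch_size rows are the positive items of this batch.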
pos_item_emb = item_emb[:batch_size]
simi_pow = self._model_config.simi_pow
interest_weight = torch.einsum("bkd, bd->bk", user_interests, pos_item_emb)
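        # Clamp scores of padded capsules to -1e32 (valid capsules keep a
        # +1e32 ceiling, i.e. stay untouched) so softmax assigns them ~0
        # weight.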
threshold = (interest_mask.float() * 2 - 1) * 1e32
interest_weight = torch.minimum(interest_weight, threshold)
interest_weight = interest_weight.unsqueeze(-1)
interest_weight = interest_weight * simi_pow
interest_weight = torch.nn.functional.softmax(interest_weight, dim=1)
user_emb = torch.sum(torch.multiply(interest_weight, user_interests), dim=1)
return user_emb

    def predict(self, batch: Batch) -> Dict[str, Tensor]:
        """Forward the model.

        Args:
            batch (Batch): input batch data.

        Returns:
            simi (dict): a dict of predicted results.
        """
if self.is_inference:
user_interests = self.user_tower(batch)
else:
user_interests, interest_mask = self.user_tower(batch)
item_emb = self.item_tower(batch)
user_emb = self.label_aware_attention(
user_interests,
item_emb,
interest_mask, # pyre-ignore [61]
)
# if self._model_config.similarity == simi_pb2.Similarity.COSINE:
# user_emb = F.normalize(user_emb, p=2.0, dim=1)
ui_sim = self.sim(user_emb, item_emb) / self._model_config.temperature
return {"similarity": ui_sim}
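

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only, not part of the library):
    # mirrors the label_aware_attention math above on random tensors. Shapes
    # and the simi_pow value below are made-up assumptions.
    B, K, D = 4, 3, 8
    user_interests = torch.randn(B, K, D)
    pos_item_emb = torch.randn(B, D)
    interest_mask = torch.ones(B, K, dtype=torch.bool)
    interest_mask[:, -1] = False  # pretend the last capsule is padding
    simi_pow = 20.0

    # Dot-product score of every interest against the positive item: [B, K].
    weight = torch.einsum("bkd,bd->bk", user_interests, pos_item_emb)
    # Padded capsules are clamped to -1e32 so softmax gives them ~0 weight.
    weight = torch.minimum(weight, (interest_mask.float() * 2 - 1) * 1e32)
    weight = F.softmax(weight.unsqueeze(-1) * simi_pow, dim=1)
    user_emb = torch.sum(weight * user_interests, dim=1)
    print(user_emb.shape)  # torch.Size([4, 8])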