tzrec/modules/embedding.py (1,102 lines of code) (raw):

# Copyright (c) 2024, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from collections import OrderedDict, defaultdict from typing import Dict, List, NamedTuple, Optional, Tuple, Union import torch from torch import nn from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig from torchrec.modules.embedding_modules import ( EmbeddingBagCollection, EmbeddingCollection, ) from torchrec.modules.mc_embedding_modules import ( ManagedCollisionEmbeddingBagCollection, ManagedCollisionEmbeddingCollection, ) from torchrec.modules.mc_modules import ( ManagedCollisionCollection, ManagedCollisionModule, MCHManagedCollisionModule, ) from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor from tzrec.acc.utils import is_input_tile, is_input_tile_emb from tzrec.datasets.utils import Batch from tzrec.features.feature import BaseFeature from tzrec.modules.dense_embedding_collection import ( DenseEmbeddingCollection, ) from tzrec.modules.sequence import create_seq_encoder from tzrec.protos import model_pb2 from tzrec.protos.model_pb2 import FeatureGroupConfig, SeqGroupConfig from tzrec.utils.fx_util import fx_int_item EMPTY_KJT = KeyedJaggedTensor.empty() torch.fx.wrap(fx_int_item) @torch.fx.wrap def _update_dict_tensor( dict1: Dict[str, torch.Tensor], dict2: Dict[str, torch.Tensor] ) -> Dict[str, torch.Tensor]: for key, value in dict2.items(): if key in dict1: dict1[key] = torch.cat([dict1[key], value], dim=-1) else: dict1[key] = value return dict1 @torch.fx.wrap def _merge_list_of_tensor_dict( list_of_tensor_dict: List[Dict[str, torch.Tensor]], ) -> Dict[str, torch.Tensor]: result = {} for tensor_dict in list_of_tensor_dict: result.update(tensor_dict) return result @torch.fx.wrap def _merge_list_of_dict_of_jt_dict( list_of_dict_of_jt_odict: List[Dict[str, Dict[str, JaggedTensor]]], ) -> Dict[str, Dict[str, JaggedTensor]]: result: Dict[str, Dict[str, JaggedTensor]] = {} for tensor_d_d in list_of_dict_of_jt_odict: result.update(tensor_d_d) return result @torch.fx.wrap def _merge_list_of_jt_dict( list_of_jt_dict: List[Dict[str, JaggedTensor]], ) -> Dict[str, JaggedTensor]: result: Dict[str, JaggedTensor] = {} for jt_dict in list_of_jt_dict: result.update(jt_dict) return result @torch.fx.wrap def _tile_and_combine_dense_kt( user_kt: Optional[KeyedTensor], item_kt: Optional[KeyedTensor], tile_size: int ) -> KeyedTensor: kt_keys: List[str] = [] kt_length_per_key: List[int] = [] kt_values: List[torch.Tensor] = [] if user_kt is not None: kt_keys.extend(user_kt.keys()) kt_length_per_key.extend(user_kt.length_per_key()) kt_values.append(user_kt.values().tile(tile_size, 1)) if item_kt is not None: kt_keys.extend(item_kt.keys()) kt_length_per_key.extend(item_kt.length_per_key()) kt_values.append(item_kt.values()) return KeyedTensor( keys=kt_keys, length_per_key=kt_length_per_key, values=torch.cat(kt_values, dim=1), ) @torch.fx.wrap def _dense_to_jt(t: torch.Tensor) -> JaggedTensor: return JaggedTensor( values=t, lengths=torch.ones( t.shape[0], dtype=torch.int64, device=t.device, ), ) class EmbeddingGroup(nn.Module): """Applies embedding lookup transformation for feature group. Args: features (list): list of features. feature_groups (list): list of feature group config. wide_embedding_dim (int, optional): wide group feature embedding dim. device (torch.device): embedding device, default is meta. """ def __init__( self, features: List[BaseFeature], feature_groups: List[FeatureGroupConfig], wide_embedding_dim: Optional[int] = None, device: Optional[torch.device] = None, ) -> None: super().__init__() if device is None: device = torch.device("meta") self._features = features self._feature_groups = feature_groups self._name_to_feature = {x.name: x for x in features} self._name_to_feature_group = {x.group_name: x for x in feature_groups} self.emb_impls = nn.ModuleDict() self.seq_emb_impls = nn.ModuleDict() self.seq_encoders = nn.ModuleDict() self._impl_key_to_feat_groups = defaultdict(list) self._impl_key_to_seq_groups = defaultdict(list) self._group_name_to_impl_key = dict() self._group_name_to_seq_encoder_configs = defaultdict(list) self._grouped_features_keys = list() seq_group_names = [] for feature_group in feature_groups: group_name = feature_group.group_name self._inspect_and_supplement_feature_group(feature_group, seq_group_names) # self._add_feature_group_sign_for_sequence_groups(feature_group) features_data_group = defaultdict(list) for feature_name in feature_group.feature_names: feature = self._name_to_feature[feature_name] features_data_group[feature.data_group].append(feature_name) for sequence_group in feature_group.sequence_groups: for feature_name in sequence_group.feature_names: feature = self._name_to_feature[feature_name] features_data_group[feature.data_group].append(feature_name) if len(features_data_group) > 1: error_info = [",".join(v) for v in features_data_group.values()] raise ValueError( f"Feature {error_info} should not belong to same feature group." ) impl_key = list(features_data_group.keys())[0] self._group_name_to_impl_key[group_name] = impl_key if feature_group.group_type == model_pb2.SEQUENCE: self._impl_key_to_seq_groups[impl_key].append(feature_group) self._grouped_features_keys.append(group_name + ".query") self._grouped_features_keys.append(group_name + ".sequence") self._grouped_features_keys.append(group_name + ".sequence_length") else: self._impl_key_to_feat_groups[impl_key].append(feature_group) if len(feature_group.sequence_groups) > 0: self._impl_key_to_seq_groups[impl_key].extend( list(feature_group.sequence_groups) ) if len(feature_group.sequence_encoders) > 0: self._group_name_to_seq_encoder_configs[group_name] = list( feature_group.sequence_encoders ) self._grouped_features_keys.append(group_name) for k, v in self._impl_key_to_feat_groups.items(): self.emb_impls[k] = EmbeddingGroupImpl( features, feature_groups=v, wide_embedding_dim=wide_embedding_dim, device=device, ) for k, v in self._impl_key_to_seq_groups.items(): self.seq_emb_impls[k] = SequenceEmbeddingGroupImpl( features, feature_groups=v, device=device ) self._group_name_to_seq_encoders = nn.ModuleDict() for ( group_name, seq_encoder_configs, ) in self._group_name_to_seq_encoder_configs.items(): impl_key = self._group_name_to_impl_key[group_name] seq_emb = self.seq_emb_impls[impl_key] group_seq_encoders = nn.ModuleList() for seq_encoder_config in seq_encoder_configs: seq_encoder = create_seq_encoder( seq_encoder_config, seq_emb.all_group_total_dim() ) group_seq_encoders.append(seq_encoder) self._group_name_to_seq_encoders[group_name] = group_seq_encoders self._group_feature_dims = OrderedDict() for feature_group in feature_groups: group_name = feature_group.group_name if feature_group.group_type != model_pb2.SEQUENCE: feature_dim = OrderedDict() impl_key = self._group_name_to_impl_key[group_name] feature_emb = self.emb_impls[impl_key] feature_dim.update(feature_emb.group_feature_dims(group_name)) if group_name in self._group_name_to_seq_encoders: seq_encoders = self._group_name_to_seq_encoders[group_name] for i, seq_encoder in enumerate(seq_encoders): feature_dim[f"{group_name}_seq_encoder_{i}"] = ( seq_encoder.output_dim() ) self._group_feature_dims[group_name] = feature_dim self._grouped_features_keys.sort() def grouped_features_keys(self) -> List[str]: """grouped_features_keys.""" return self._grouped_features_keys def _inspect_and_supplement_feature_group( self, feature_group: FeatureGroupConfig, seq_group_names: List[str] ) -> None: """Inspect feature group sequence_groups and sequence_encoders.""" group_name = feature_group.group_name sequence_groups = list(feature_group.sequence_groups) sequence_encoders = list(feature_group.sequence_encoders) is_deep = feature_group.group_type == model_pb2.DEEP if is_deep: if len(sequence_groups) == 0 and sequence_encoders == 0: return elif len(sequence_groups) > 0 and sequence_encoders == 0: raise ValueError( f"{group_name} group has sequence_groups,but no sequence_encoders " ) elif len(sequence_groups) == 0 and len(sequence_encoders) > 0: raise ValueError( f"{group_name} group has sequence_encoders,but no sequence_groups " ) if len(sequence_groups) > 1: for sequence_group in sequence_groups: if not sequence_group.HasField("group_name"): raise ValueError( f"{group_name} has many sequence_groups, " f"every sequence_group must has group_name" ) elif len(sequence_groups) == 1 and not sequence_groups[0].HasField( "group_name" ): sequence_groups[0].group_name = group_name for sequence_group in sequence_groups: if sequence_group.group_name in seq_group_names: raise ValueError( f"has repeat sequences groups_name: {sequence_group.group_name}" ) else: seq_group_names.append(sequence_group.group_name) group_has_encoder = { sequence_group.group_name: False for sequence_group in sequence_groups } for sequence_encoder in sequence_encoders: seq_type = sequence_encoder.WhichOneof("seq_module") seq_config = getattr(sequence_encoder, seq_type) if not seq_config.HasField("input") and len(sequence_groups) == 1: seq_config.input = sequence_groups[0].group_name if not seq_config.HasField("input"): raise ValueError( f"{group_name} group has multi sequence_groups, " f"so sequence_encoders must has input" ) if seq_config.input not in group_has_encoder: raise ValueError( f"{group_name} sequence_encoder input {seq_config.input} " f"not in sequence_groups" ) else: group_has_encoder[seq_config.input] = True for k, v in group_has_encoder.items(): if not v: raise ValueError( f"{group_name} sequence_groups {k} not has seq_encoder" ) else: if len(sequence_groups) > 0 or len(sequence_encoders) > 0: raise ValueError( f"{group_name} group group_type is not DEEP, " f"sequence_groups and sequence_encoders must configured in DEEP" ) def group_names(self) -> List[str]: """Feature group names.""" return list(self._name_to_feature_group.keys()) def group_dims(self, group_name: str) -> List[int]: """Output dimension of each feature in a feature group. Args: group_name (str): feature group name, when group type is sequence, should use {group_name}.query or {group_name}.sequence. Return: group_dims (list): output dimension of each feature. """ true_name = group_name.split(".")[0] if "." in group_name else group_name feature_group = self._name_to_feature_group[true_name] impl_key = self._group_name_to_impl_key[true_name] if feature_group.group_type == model_pb2.SEQUENCE: return self.seq_emb_impls[impl_key].group_dims(group_name) else: dims = self.emb_impls[impl_key].group_dims(group_name) if group_name in self._group_name_to_seq_encoders: for seq_encoder in self._group_name_to_seq_encoders[group_name]: dims.append(seq_encoder.output_dim()) return dims def group_total_dim(self, group_name: str) -> int: """Total output dimension of a feature group. Args: group_name (str): feature group name, when group type is sequence, should use {group_name}.query or {group_name}.sequence. Return: total_dim (int): total dimension of feature group. """ true_name = group_name.split(".")[0] if "." in group_name else group_name feature_group = self._name_to_feature_group[true_name] impl_key = self._group_name_to_impl_key[true_name] if feature_group.group_type == model_pb2.SEQUENCE: return self.seq_emb_impls[impl_key].group_total_dim(group_name) else: return sum(self._group_feature_dims[group_name].values()) def group_feature_dims(self, group_name: str) -> Dict[str, int]: """Every feature group each feature dim.""" true_name = group_name.split(".")[0] if "." in group_name else group_name feature_group = self._name_to_feature_group[true_name] if feature_group.group_type == model_pb2.SEQUENCE: raise ValueError("not support sequence group") return self._group_feature_dims[group_name] def has_group(self, group_name: str) -> bool: """Check the feature group exist or not.""" true_name = group_name.split(".")[0] if "." in group_name else group_name return true_name in self._name_to_feature_group.keys() def forward( self, batch: Batch, ) -> Dict[str, torch.Tensor]: """Forward the module. Args: batch (Batch): a instance of Batch with features. Returns: group_features (dict): dict of feature_group to embedded tensor. """ result_dicts = [] need_input_tile = is_input_tile() if need_input_tile: emb_keys = list(self.emb_impls.keys()) seq_emb_keys = list(self.seq_emb_impls.keys()) unique_keys = list(set(emb_keys + seq_emb_keys)) # tile user dense feat & combine item dense feat, when has user dense feat for key in unique_keys: user_kt = batch.dense_features.get(key + "_user", None) if user_kt is not None: item_kt = batch.dense_features.get(key, None) batch.dense_features[key] = _tile_and_combine_dense_kt( user_kt, item_kt, batch.tile_size ) for key, emb_impl in self.emb_impls.items(): sparse_feat_kjt = None sparse_feat_kjt_user = None dense_feat_kt = None if emb_impl.has_dense: dense_feat_kt = batch.dense_features[key] if emb_impl.has_sparse or emb_impl.has_mc_sparse: sparse_feat_kjt = batch.sparse_features[key] if emb_impl.has_sparse_user or emb_impl.has_mc_sparse_user: sparse_feat_kjt_user = batch.sparse_features[key + "_user"] result_dicts.append( emb_impl( sparse_feat_kjt, dense_feat_kt, sparse_feat_kjt_user, batch.tile_size, ) ) for key, seq_emb_impl in self.seq_emb_impls.items(): sparse_feat_kjt = None sparse_feat_kjt_user = None dense_feat_kt = None sequence_mulval_length_kjt = None sequence_mulval_length_kjt_user = None if seq_emb_impl.has_dense: dense_feat_kt = batch.dense_features[key] if seq_emb_impl.has_sparse or seq_emb_impl.has_mc_sparse: sparse_feat_kjt = batch.sparse_features[key] if seq_emb_impl.has_mulval_seq: sequence_mulval_length_kjt = batch.sequence_mulval_lengths[key] if seq_emb_impl.has_sparse_user or seq_emb_impl.has_mc_sparse_user: sparse_feat_kjt_user = batch.sparse_features[key + "_user"] if seq_emb_impl.has_mulval_seq_user: sequence_mulval_length_kjt_user = batch.sequence_mulval_lengths[ key + "_user" ] result_dicts.append( seq_emb_impl( sparse_feat_kjt, dense_feat_kt, batch.sequence_dense_features, sequence_mulval_length_kjt, sparse_feat_kjt_user, sequence_mulval_length_kjt_user, batch.tile_size, ) ) result = _merge_list_of_tensor_dict(result_dicts) seq_feature_dict = {} for group_name, seq_encoders in self._group_name_to_seq_encoders.items(): new_feature = [] for seq_encoder in seq_encoders: new_feature.append(seq_encoder(result)) seq_feature_dict[group_name] = torch.cat(new_feature, dim=-1) return _update_dict_tensor(result, seq_feature_dict) def predict( self, batch: Batch, ) -> List[torch.Tensor]: """Predict embedding module and return a list of grouped embedding features.""" grouped_features = self.forward(batch) values_list = [] for key in self._grouped_features_keys: values_list.append(grouped_features[key]) return values_list def _add_embedding_bag_config( emb_bag_configs: Dict[str, EmbeddingBagConfig], emb_bag_config: EmbeddingBagConfig ) -> None: """Add embedding bag config to a dict of embedding bag config. Args: emb_bag_configs(Dict[str, EmbeddingBagConfig]): a dict contains emb_bag_configs emb_bag_config(EmbeddingBagConfig): an instance of EmbeddingBagConfig """ if emb_bag_config.name in emb_bag_configs: existed_emb_bag_config = emb_bag_configs[emb_bag_config.name] assert ( emb_bag_config.num_embeddings == existed_emb_bag_config.num_embeddings and emb_bag_config.embedding_dim == existed_emb_bag_config.embedding_dim and emb_bag_config.pooling == existed_emb_bag_config.pooling and repr(emb_bag_config.init_fn) == repr(existed_emb_bag_config.init_fn) ), ( f"there is a mismatch between {emb_bag_config} and " f"{existed_emb_bag_config}, can not share embedding." ) for feature_name in emb_bag_config.feature_names: if feature_name not in existed_emb_bag_config.feature_names: existed_emb_bag_config.feature_names.append(feature_name) else: emb_bag_configs[emb_bag_config.name] = emb_bag_config def _add_embedding_config( emb_configs: Dict[str, EmbeddingConfig], emb_config: EmbeddingConfig ) -> None: """Add embedding config to a dict of embedding config. Args: emb_configs(Dict[str, EmbeddingConfig]): a dict contains emb_configs emb_config(EmbeddingConfig): an instance of EmbeddingConfig """ if emb_config.name in emb_configs: existed_emb_config = emb_configs[emb_config.name] assert ( emb_config.num_embeddings == existed_emb_config.num_embeddings and emb_config.embedding_dim == existed_emb_config.embedding_dim and repr(emb_config.init_fn) == repr(existed_emb_config.init_fn) ), ( f"there is a mismatch between {emb_config} and " f"{existed_emb_config}, can not share embedding." ) for feature_name in emb_config.feature_names: if feature_name not in existed_emb_config.feature_names: existed_emb_config.feature_names.append(feature_name) else: emb_configs[emb_config.name] = emb_config def _add_mc_module( mc_modules: Dict[str, ManagedCollisionModule], emb_name: str, mc_module: ManagedCollisionModule, ) -> None: """Add ManagedCollisionModule to a dict of ManagedCollisionModule. Args: mc_modules(Dict[str, ManagedCollisionModule]): a dict of ManagedCollisionModule. emb_name(str): embedding_name. mc_module(ManagedCollisionModule): an instance of ManagedCollisionModule. """ if emb_name in mc_modules: existed_mc_module = mc_modules[emb_name] if isinstance(mc_module, MCHManagedCollisionModule): assert isinstance(existed_mc_module, MCHManagedCollisionModule) assert mc_module._zch_size == existed_mc_module._zch_size assert mc_module._eviction_interval == existed_mc_module._eviction_interval assert repr(mc_module._eviction_policy) == repr(mc_module._eviction_policy) mc_modules[emb_name] = mc_module class EmbeddingGroupImpl(nn.Module): """Applies embedding lookup transformation for feature group. Args: features (list): list of features. feature_groups (list): list of feature group config. wide_embedding_dim (int, optional): wide group feature embedding dim. device (torch.device): embedding device, default is meta. """ def __init__( self, features: List[BaseFeature], feature_groups: List[FeatureGroupConfig], wide_embedding_dim: Optional[int] = None, device: Optional[torch.device] = None, ) -> None: super().__init__() if device is None: device = torch.device("meta") name_to_feature = {x.name: x for x in features} need_input_tile_emb = is_input_tile_emb() emb_bag_configs = OrderedDict() mc_emb_bag_configs = OrderedDict() mc_modules = OrderedDict() dense_embedding_configs = [] self.has_sparse = False self.has_mc_sparse = False self.has_dense = False self.has_dense_embedding = False # for sparse input-tile-emb emb_bag_configs_user = OrderedDict() mc_emb_bag_configs_user = OrderedDict() mc_modules_user = OrderedDict() self.has_sparse_user = False self.has_mc_sparse_user = False self._group_to_feature_names = OrderedDict() self._group_to_shared_feature_names = OrderedDict() self._group_total_dim = dict() self._group_feature_output_dims = dict() self._group_dense_feature_names = dict() self._group_dense_embedding_feature_names = dict() feat_to_group_to_emb_name = defaultdict(dict) for feature_group in feature_groups: group_name = feature_group.group_name for feature_name in feature_group.feature_names: feature = name_to_feature[feature_name] if feature.is_sparse: emb_bag_config = feature.emb_bag_config # pyre-ignore [16] emb_name = emb_bag_config.name if feature_group.group_type == model_pb2.WIDE: emb_name = emb_name + "_wide" feat_to_group_to_emb_name[feature_name][group_name] = emb_name shared_feature_flag = dict() for feature_name, group_to_emb_name in feat_to_group_to_emb_name.items(): if len(set(group_to_emb_name.values())) > 1: shared_feature_flag[feature_name] = True else: shared_feature_flag[feature_name] = False non_emb_dense_feature_to_dim = OrderedDict() for feature_group in feature_groups: total_dim = 0 feature_output_dims = OrderedDict() group_name = feature_group.group_name feature_names = list(feature_group.feature_names) shared_feature_names = [] is_wide = feature_group.group_type == model_pb2.WIDE for name in feature_names: shared_name = name feature = name_to_feature[name] if feature.is_sparse: output_dim = feature.output_dim emb_bag_config = feature.emb_bag_config mc_module = feature.mc_module(device) assert emb_bag_config is not None if is_wide: # TODO(hongsheng.jhs): change to embedding_dim to 1 # when fbgemm support embedding_dim=1 emb_bag_config.embedding_dim = output_dim = ( wide_embedding_dim or 4 ) # we may modify ebc name at feat_to_group_to_emb_name, e.g., wide emb_bag_config.name = feat_to_group_to_emb_name[name][group_name] if need_input_tile_emb and feature.is_user_feat: _add_embedding_bag_config( emb_bag_configs=mc_emb_bag_configs_user if mc_module else emb_bag_configs_user, emb_bag_config=emb_bag_config, ) if mc_module: _add_mc_module( mc_modules_user, emb_bag_config.name, mc_module ) self.has_mc_sparse_user = True else: self.has_sparse_user = True else: _add_embedding_bag_config( emb_bag_configs=mc_emb_bag_configs if mc_module else emb_bag_configs, emb_bag_config=emb_bag_config, ) if mc_module: _add_mc_module(mc_modules, emb_bag_config.name, mc_module) self.has_mc_sparse = True else: self.has_sparse = True if shared_feature_flag[name]: shared_name = shared_name + "@" + emb_bag_config.name else: output_dim = feature.output_dim if is_wide: raise ValueError( f"dense feature [{name}] should not be configured in " "wide group." ) else: self.has_dense = True if feature.dense_emb_config: self.has_dense_embedding = True conf_obj = feature.dense_emb_config dense_embedding_configs.append(conf_obj) else: non_emb_dense_feature_to_dim[name] = output_dim total_dim += output_dim feature_output_dims[name] = output_dim shared_feature_names.append(shared_name) self._group_to_feature_names[group_name] = feature_names if len(shared_feature_names) > 0: self._group_to_shared_feature_names[group_name] = shared_feature_names self._group_total_dim[group_name] = total_dim self._group_feature_output_dims[group_name] = feature_output_dims self.ebc = EmbeddingBagCollection(list(emb_bag_configs.values()), device=device) if self.has_mc_sparse: self.mc_ebc = ManagedCollisionEmbeddingBagCollection( EmbeddingBagCollection( list(mc_emb_bag_configs.values()), device=device ), ManagedCollisionCollection( mc_modules, list(mc_emb_bag_configs.values()) ), ) if self.has_dense_embedding: self.dense_ec = DenseEmbeddingCollection( dense_embedding_configs, device=device, raw_dense_feature_to_dim=non_emb_dense_feature_to_dim, ) if need_input_tile_emb: self.ebc_user = EmbeddingBagCollection( list(emb_bag_configs_user.values()), device=device ) if self.has_mc_sparse_user: self.mc_ebc_user = ManagedCollisionEmbeddingBagCollection( EmbeddingBagCollection( list(mc_emb_bag_configs_user.values()), device=device ), ManagedCollisionCollection( mc_modules_user, list(mc_emb_bag_configs_user.values()) ), ) def group_dims(self, group_name: str) -> List[int]: """Output dimension of each feature in a feature group.""" return list(self._group_feature_output_dims[group_name].values()) def group_feature_dims(self, group_name: str) -> Dict[str, int]: """Output dimension of each feature in a feature group.""" return self._group_feature_output_dims[group_name] def group_total_dim(self, group_name: str) -> int: """Total output dimension of a feature group.""" return self._group_total_dim[group_name] def forward( self, sparse_feature: KeyedJaggedTensor, dense_feature: KeyedTensor, sparse_feature_user: KeyedJaggedTensor, tile_size: int = -1, ) -> Dict[str, torch.Tensor]: """Forward the module. Args: sparse_feature (KeyedJaggedTensor): sparse id feature. dense_feature (dense_feature): dense feature. sparse_feature_user (KeyedJaggedTensor): user-side sparse feature with batch_size=1, when use INPUT_TILE=3. tile_size: size for user-side feature input tile. Returns: group_features (dict): dict of feature_group to embedded tensor. """ kts: List[KeyedTensor] = [] if self.has_sparse: kts.append(self.ebc(sparse_feature)) if self.has_mc_sparse: kts.append(self.mc_ebc(sparse_feature)[0]) # do user-side embedding input-tile if self.has_sparse_user: keyed_tensor_user = self.ebc_user(sparse_feature_user) values_tile = keyed_tensor_user.values().tile(tile_size, 1) keyed_tensor_user_tile = KeyedTensor( keys=keyed_tensor_user.keys(), length_per_key=keyed_tensor_user.length_per_key(), values=values_tile, ) kts.append(keyed_tensor_user_tile) # do user-side mc embedding input-tile if self.has_mc_sparse_user: keyed_tensor_user = self.mc_ebc_user(sparse_feature_user)[0] values_tile = keyed_tensor_user.values().tile(tile_size, 1) keyed_tensor_user_tile = KeyedTensor( keys=keyed_tensor_user.keys(), length_per_key=keyed_tensor_user.length_per_key(), values=values_tile, ) kts.append(keyed_tensor_user_tile) if self.has_dense: if self.has_dense_embedding: kts.append(self.dense_ec(dense_feature)) else: kts.append(dense_feature) group_tensors = KeyedTensor.regroup_as_dict( kts, list(self._group_to_shared_feature_names.values()), list(self._group_to_shared_feature_names.keys()), ) return group_tensors class SequenceEmbeddingGroup(nn.Module): """Applies embedding lookup transformation for feature group. Args: features (list): list of features. feature_groups (list): list of feature group config. wide_embedding_dim (int, optional): wide group feature embedding dim. device (torch.device): embedding device, default is meta. """ def __init__( self, features: List[BaseFeature], feature_groups: List[FeatureGroupConfig], device: Optional[torch.device] = None, ) -> None: super().__init__() if device is None: device = torch.device("meta") self._features = features self._feature_groups = feature_groups self._name_to_feature = {x.name: x for x in features} self._name_to_feature_group = {x.group_name: x for x in feature_groups} self.seq_emb_impls = nn.ModuleDict() self._impl_key_to_seq_groups = defaultdict(list) self._group_name_to_impl_key = dict() for feature_group in feature_groups: assert feature_group.group_type == model_pb2.SEQUENCE group_name = feature_group.group_name features_data_group = defaultdict(list) for feature_name in feature_group.feature_names: feature = self._name_to_feature[feature_name] features_data_group[feature.data_group].append(feature_name) if len(features_data_group) > 1: error_info = [",".join(v) for v in features_data_group.values()] raise ValueError( f"Feature {error_info} should not belong to same feature group." ) impl_key = list(features_data_group.keys())[0] self._group_name_to_impl_key[group_name] = impl_key self._impl_key_to_seq_groups[impl_key].append(feature_group) for k, v in self._impl_key_to_seq_groups.items(): self.seq_emb_impls[k] = SequenceEmbeddingGroupImpl( features, feature_groups=v, device=device ) def group_names(self) -> List[str]: """Feature group names.""" return list(self._name_to_feature_group.keys()) def group_dims(self, group_name: str) -> List[int]: """Output dimension of each feature in a feature group. Args: group_name (str): feature group name, when group type is sequence, should use {group_name}.query or {group_name}.sequence. Return: group_dims (list): output dimension of each feature. """ true_name = group_name.split(".")[0] if "." in group_name else group_name impl_key = self._group_name_to_impl_key[true_name] return self.seq_emb_impls[impl_key].group_dims(group_name) def group_total_dim(self, group_name: str) -> int: """Total output dimension of a feature group. Args: group_name (str): feature group name, when group type is sequence, should use {group_name}.query or {group_name}.sequence. Return: total_dim (int): total dimension of feature group. """ true_name = group_name.split(".")[0] if "." in group_name else group_name impl_key = self._group_name_to_impl_key[true_name] return self.seq_emb_impls[impl_key].group_total_dim(group_name) def has_group(self, group_name: str) -> bool: """Check the feature group exist or not.""" true_name = group_name.split(".")[0] if "." in group_name else group_name return true_name in self._name_to_feature_group.keys() def forward( self, batch: Batch, ) -> Dict[str, torch.Tensor]: """Forward the module. Args: batch (Batch): a instance of Batch with features. Returns: group_features (dict): dict of feature_group to embedded tensor. """ result_dicts = [] need_input_tile = is_input_tile() if need_input_tile: unique_keys = list(self.seq_emb_impls.keys()) # tile user dense feat & combine item dense feat, when has user dense feat for key in unique_keys: user_kt = batch.dense_features.get(key + "_user", None) if user_kt is not None: item_kt = batch.dense_features.get(key, None) batch.dense_features[key] = _tile_and_combine_dense_kt( user_kt, item_kt, batch.tile_size ) for key, seq_emb_impl in self.seq_emb_impls.items(): sparse_feat_kjt = None sparse_feat_kjt_user = None dense_feat_kt = None sequence_mulval_length_kjt = None sequence_mulval_length_kjt_user = None if seq_emb_impl.has_dense: dense_feat_kt = batch.dense_features[key] if seq_emb_impl.has_sparse or seq_emb_impl.has_mc_sparse: sparse_feat_kjt = batch.sparse_features[key] if seq_emb_impl.has_mulval_seq: sequence_mulval_length_kjt = batch.sequence_mulval_lengths[key] if seq_emb_impl.has_sparse_user or seq_emb_impl.has_mc_sparse_user: sparse_feat_kjt_user = batch.sparse_features[key + "_user"] if seq_emb_impl.has_mulval_seq_user: sequence_mulval_length_kjt_user = batch.sequence_mulval_lengths[ key + "_user" ] result_dicts.append( seq_emb_impl( sparse_feat_kjt, dense_feat_kt, batch.sequence_dense_features, sequence_mulval_length_kjt, sparse_feat_kjt_user, sequence_mulval_length_kjt_user, batch.tile_size, ) ) return _merge_list_of_tensor_dict(result_dicts) def jagged_forward( self, batch: Batch, ) -> Dict[str, OrderedDict[str, JaggedTensor]]: """Forward the module. Args: batch (Batch): a instance of Batch with features. Returns: group_features (dict): dict of feature_group to dict of embedded tensor. """ need_input_tile = is_input_tile() assert not need_input_tile, "jagged forward not support INPUT_TILE now." result_dicts = [] for key, seq_emb_impl in self.seq_emb_impls.items(): sparse_feat_kjt = None dense_feat_kt = None sequence_mulval_length_kjt = None if seq_emb_impl.has_dense: dense_feat_kt = batch.dense_features[key] if seq_emb_impl.has_sparse or seq_emb_impl.has_mc_sparse: sparse_feat_kjt = batch.sparse_features[key] if seq_emb_impl.has_mulval_seq: sequence_mulval_length_kjt = batch.sequence_mulval_lengths[key] result_dicts.append( seq_emb_impl.jagged_forward( sparse_feat_kjt, dense_feat_kt, batch.sequence_dense_features, sequence_mulval_length_kjt, ) ) result = _merge_list_of_dict_of_jt_dict(result_dicts) return result class _SequenceEmbeddingInfo(NamedTuple): """One Embedding info in SequenceGroup.""" name: str raw_name: str is_sparse: bool pooling: str value_dim: int is_user: bool is_sequence: bool class SequenceEmbeddingGroupImpl(nn.Module): """Applies embedding lookup transformation for sequence feature group. Args: features (list): list of features. feature_groups (list): list of feature group config or seq group config. device (torch.device): embedding device, default is meta. """ def __init__( self, features: List[BaseFeature], feature_groups: List[Union[FeatureGroupConfig, SeqGroupConfig]], device: Optional[torch.device] = None, ) -> None: super().__init__() if device is None: device = torch.device("meta") name_to_feature = {x.name: x for x in features} need_input_tile = is_input_tile() need_input_tile_emb = is_input_tile_emb() dim_to_emb_configs = defaultdict(OrderedDict) dim_to_mc_emb_configs = defaultdict(OrderedDict) dim_to_mc_modules = defaultdict(OrderedDict) self.has_sparse = False self.has_mc_sparse = False self.has_dense = False self.has_sequence_dense = False self.has_mulval_seq = False # for sparse input-tile-emb dim_to_emb_configs_user = defaultdict(OrderedDict) dim_to_mc_emb_configs_user = defaultdict(OrderedDict) dim_to_mc_modules_user = defaultdict(OrderedDict) self.has_sparse_user = False self.has_mc_sparse_user = False self.has_mulval_seq_user = False self._group_to_shared_query = OrderedDict() self._group_to_shared_sequence = OrderedDict() self._group_to_shared_feature = OrderedDict() self._group_total_dim = dict() self._group_output_dims = dict() feat_to_group_to_emb_name = defaultdict(dict) for feature_group in feature_groups: group_name = feature_group.group_name for feature_name in feature_group.feature_names: feature = name_to_feature[feature_name] if feature.is_sparse: emb_config = feature.emb_config # pyre-ignore [16] emb_name = emb_config.name feat_to_group_to_emb_name[feature_name][group_name] = emb_name shared_feature_flag = dict() for feature_name, group_to_emb_name in feat_to_group_to_emb_name.items(): if len(set(group_to_emb_name.values())) > 1: shared_feature_flag[feature_name] = True else: shared_feature_flag[feature_name] = False for feature_group in feature_groups: query_dim = 0 sequence_dim = 0 query_dims = [] sequence_dims = [] output_dims = [] group_name = feature_group.group_name feature_names = list(feature_group.feature_names) shared_query = [] shared_sequence = [] shared_feature = [] for name in feature_names: shared_name = name feature = name_to_feature[name] if feature.is_sparse: output_dim = feature.output_dim emb_config = feature.emb_config mc_module = feature.mc_module(device) assert emb_config is not None # we may/could modify ec name at feat_to_group_to_emb_name emb_config.name = feat_to_group_to_emb_name[name][group_name] embedding_dim = emb_config.embedding_dim if need_input_tile_emb and feature.is_user_feat: emb_configs = ( dim_to_mc_emb_configs_user[embedding_dim] if mc_module else dim_to_emb_configs_user[embedding_dim] ) _add_embedding_config( emb_configs=emb_configs, emb_config=emb_config, ) if mc_module: _add_mc_module( dim_to_mc_modules_user[embedding_dim], emb_config.name, mc_module, ) self.has_mc_sparse_user = True else: self.has_sparse_user = True if feature.is_sequence and feature.value_dim != 1: self.has_mulval_seq_user = True else: emb_configs = ( dim_to_mc_emb_configs[embedding_dim] if mc_module else dim_to_emb_configs[embedding_dim] ) _add_embedding_config( emb_configs=emb_configs, emb_config=emb_config, ) if mc_module: _add_mc_module( dim_to_mc_modules[embedding_dim], emb_config.name, mc_module, ) self.has_mc_sparse = True else: self.has_sparse = True if feature.is_sequence and feature.value_dim != 1: self.has_mulval_seq = True if shared_feature_flag[name]: shared_name = shared_name + "@" + emb_config.name else: output_dim = feature.output_dim if feature.is_sequence: self.has_sequence_dense = True else: self.has_dense = True is_user_feat = feature.is_user_feat if need_input_tile else False shared_info = _SequenceEmbeddingInfo( name=shared_name, raw_name=name, is_sparse=feature.is_sparse, pooling=feature.pooling_type.value.lower(), value_dim=feature.value_dim, is_user=is_user_feat, is_sequence=feature.is_sequence, ) if feature.is_sequence: shared_sequence.append(shared_info) sequence_dim += output_dim sequence_dims.append(output_dim) else: shared_query.append(shared_info) query_dim += output_dim query_dims.append(output_dim) shared_feature.append(shared_info) output_dims.append(output_dim) self._group_to_shared_query[group_name] = shared_query self._group_to_shared_sequence[group_name] = shared_sequence self._group_to_shared_feature[group_name] = shared_feature self._group_total_dim[f"{group_name}.query"] = query_dim self._group_total_dim[f"{group_name}.sequence"] = sequence_dim self._group_output_dims[f"{group_name}.query"] = query_dims self._group_output_dims[f"{group_name}.sequence"] = sequence_dims self._group_output_dims[group_name] = output_dims self.ec_list = nn.ModuleList() for _, emb_configs in dim_to_emb_configs.items(): self.ec_list.append( EmbeddingCollection(list(emb_configs.values()), device=device) ) self.mc_ec_list = nn.ModuleList() for k, emb_configs in dim_to_mc_emb_configs.items(): self.mc_ec_list.append( ManagedCollisionEmbeddingCollection( EmbeddingCollection(list(emb_configs.values()), device=device), ManagedCollisionCollection( dim_to_mc_modules[k], list(emb_configs.values()) ), ) ) if need_input_tile_emb: self.ec_list_user = nn.ModuleList() for _, emb_configs in dim_to_emb_configs_user.items(): self.ec_list_user.append( EmbeddingCollection(list(emb_configs.values()), device=device) ) self.mc_ec_list_user = nn.ModuleList() for k, emb_configs in dim_to_mc_emb_configs_user.items(): self.mc_ec_list_user.append( ManagedCollisionEmbeddingCollection( EmbeddingCollection(list(emb_configs.values()), device=device), ManagedCollisionCollection( dim_to_mc_modules_user[k], list(emb_configs.values()) ), ) ) def group_dims(self, group_name: str) -> List[int]: """Output dimension of each feature in a feature group.""" return self._group_output_dims[group_name] def group_total_dim(self, group_name: str) -> int: """Total output dimension of a feature group.""" if "." in group_name: return self._group_total_dim[group_name] else: return ( self._group_total_dim[f"{group_name}.query"] + self._group_total_dim[f"{group_name}.sequence"] ) def all_group_total_dim(self) -> Dict[str, int]: """Total output dimension of all feature group.""" return self._group_total_dim def has_group(self, group_name: str) -> bool: """Check the feature group exist or not.""" true_name = group_name.split(".")[0] if "." in group_name else group_name return true_name in self._group_output_dims.keys() def _forward_impl( self, sparse_feature: KeyedJaggedTensor, dense_feature: KeyedTensor, sequence_dense_features: Dict[str, JaggedTensor], sequence_mulval_lengths: KeyedJaggedTensor, sparse_feature_user: KeyedJaggedTensor, sequence_mulval_lengths_user: KeyedJaggedTensor, ) -> Tuple[Dict[str, JaggedTensor], Dict[str, torch.Tensor]]: sparse_jt_dict_list: List[Dict[str, JaggedTensor]] = [] seq_mulval_length_jt_dict_list: List[Dict[str, JaggedTensor]] = [] dense_t_dict: Dict[str, torch.Tensor] = {} if self.has_sparse: for ec in self.ec_list: sparse_jt_dict_list.append(ec(sparse_feature)) if self.has_mc_sparse: for ec in self.mc_ec_list: sparse_jt_dict_list.append(ec(sparse_feature)[0]) if self.has_mulval_seq: seq_mulval_length_jt_dict_list.append(sequence_mulval_lengths.to_dict()) if self.has_sparse_user: for ec in self.ec_list_user: sparse_jt_dict_list.append(ec(sparse_feature_user)) if self.has_mc_sparse_user: for ec in self.mc_ec_list_user: sparse_jt_dict_list.append(ec(sparse_feature_user)[0]) if self.has_mulval_seq_user: seq_mulval_length_jt_dict_list.append( sequence_mulval_lengths_user.to_dict() ) sparse_jt_dict = _merge_list_of_jt_dict(sparse_jt_dict_list) seq_mulval_length_jt_dict = _merge_list_of_jt_dict( seq_mulval_length_jt_dict_list ) if self.has_dense: dense_t_dict = dense_feature.to_dict() seq_jt_dict: Dict[str, JaggedTensor] = {} for _, v in self._group_to_shared_sequence.items(): for info in v: if info.name in seq_jt_dict: continue jt = ( sparse_jt_dict[info.name] if info.is_sparse else sequence_dense_features[info.name] ) if info.is_sparse and info.value_dim != 1: length_jt = seq_mulval_length_jt_dict[info.raw_name] # length_jt.values is sequence key_lengths # length_jt.lengths is sequence seq_lengths jt_values = torch.segment_reduce( jt.values(), info.pooling, lengths=length_jt.values() ) if info.pooling == "mean": jt_values = torch.nan_to_num(jt_values, nan=0.0) jt = JaggedTensor( values=jt_values, lengths=length_jt.lengths(), ) seq_jt_dict[info.name] = jt if len(seq_jt_dict) > 0: jt_dict = _merge_list_of_jt_dict([sparse_jt_dict, seq_jt_dict]) else: jt_dict = sparse_jt_dict return jt_dict, dense_t_dict def forward( self, sparse_feature: KeyedJaggedTensor, dense_feature: KeyedTensor, sequence_dense_features: Dict[str, JaggedTensor], sequence_mulval_lengths: KeyedJaggedTensor, sparse_feature_user: KeyedJaggedTensor, sequence_mulval_lengths_user: KeyedJaggedTensor, tile_size: int = -1, ) -> Dict[str, torch.Tensor]: """Forward the module. Args: sparse_feature (KeyedJaggedTensor): sparse id feature. dense_feature (dense_feature): dense feature. sequence_dense_features (Dict[str, JaggedTensor]): dense sequence feature. sequence_mulval_lengths (KeyedJaggedTensor): key_lengths and seq_lengths for multi-value sparse sequence features, kjt.values is key_lengths, kjt.lengths is seq_lengths. sparse_feature_user (KeyedJaggedTensor): user-side sparse feature with batch_size=1, when use INPUT_TILE=3. sequence_mulval_lengths_user (KeyedJaggedTensor):key_lengths and seq_lengths of user-side multi-value sparse sequence features, with batch_size=1, when use INPUT_TILE=3. tile_size: size for user-side feature input tile. Returns: group_features (dict): dict of feature_group to embedded tensor. """ need_input_tile = is_input_tile() need_input_tile_emb = is_input_tile_emb() jt_dict, dense_t_dict = self._forward_impl( sparse_feature, dense_feature, sequence_dense_features, sequence_mulval_lengths, sparse_feature_user, sequence_mulval_lengths_user, ) results = {} for group_name, v in self._group_to_shared_query.items(): query_t_list = [] for info in v: if info.is_sparse: query_jt = jt_dict[info.name] if info.value_dim == 1: # for single-value id feature query_t = jt_dict[info.name].to_padded_dense(1).squeeze(1) else: # for multi-value id feature query_t = torch.segment_reduce( query_jt.values(), info.pooling, lengths=query_jt.lengths() ) if info.pooling == "mean": query_t = torch.nan_to_num(query_t, nan=0.0) if info.is_user and need_input_tile_emb: query_t = query_t.tile(tile_size, 1) else: query_t = dense_t_dict[info.name] query_t_list.append(query_t) if len(query_t_list) > 0: results[f"{group_name}.query"] = torch.cat(query_t_list, dim=1) for group_name, v in self._group_to_shared_sequence.items(): seq_t_list = [] group_sequence_length = 1 for i, info in enumerate(v): # when is_user is True # sequence_sparse_features # when input_tile_emb need to tile(tile_size,1): # sequence_dense_features always need to tile need_tile = False if info.is_user: if info.is_sparse: need_tile = need_input_tile_emb else: need_tile = need_input_tile jt = jt_dict[info.name] if i == 0: sequence_length = jt.lengths() group_sequence_length = fx_int_item(torch.max(sequence_length)) if need_tile: results[f"{group_name}.sequence_length"] = sequence_length.tile( tile_size ) else: results[f"{group_name}.sequence_length"] = sequence_length jt = jt.to_padded_dense(group_sequence_length) if need_tile: jt = jt.tile(tile_size, 1, 1) seq_t_list.append(jt) if seq_t_list: results[f"{group_name}.sequence"] = torch.cat(seq_t_list, dim=2) return results def jagged_forward( self, sparse_feature: KeyedJaggedTensor, dense_feature: KeyedTensor, sequence_dense_features: Dict[str, JaggedTensor], sequence_mulval_lengths: KeyedJaggedTensor, ) -> Dict[str, OrderedDict[str, JaggedTensor]]: """Forward the module. Args: sparse_feature (KeyedJaggedTensor): sparse id feature. dense_feature (dense_feature): dense feature. sequence_dense_features (Dict[str, JaggedTensor]): dense sequence feature. sequence_mulval_lengths (KeyedJaggedTensor): key_lengths and seq_lengths for multi-value sparse sequence features, kjt.values is key_lengths, kjt.lengths is seq_lengths. Returns: group_features (dict): dict of feature_group to dict of embedded tensor. """ need_input_tile = is_input_tile() assert not need_input_tile, "jagged forward not support INPUT_TILE now." jt_dict, dense_t_dict = self._forward_impl( sparse_feature, dense_feature, sequence_dense_features, sequence_mulval_lengths, EMPTY_KJT, EMPTY_KJT, ) results = {} for group_name, v in self._group_to_shared_feature.items(): group_result = OrderedDict() for info in v: if info.is_sparse or info.is_sequence: jt = jt_dict[info.name] else: jt = _dense_to_jt(dense_t_dict[info.name]) group_result[info.name] = jt results[group_name] = group_result return results