# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict, defaultdict
from typing import Dict, List, NamedTuple, Optional, Tuple, Union

import torch
from torch import nn
from torchrec.modules.embedding_configs import EmbeddingBagConfig, EmbeddingConfig
from torchrec.modules.embedding_modules import (
    EmbeddingBagCollection,
    EmbeddingCollection,
)
from torchrec.modules.mc_embedding_modules import (
    ManagedCollisionEmbeddingBagCollection,
    ManagedCollisionEmbeddingCollection,
)
from torchrec.modules.mc_modules import (
    ManagedCollisionCollection,
    ManagedCollisionModule,
    MCHManagedCollisionModule,
)
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor

from tzrec.acc.utils import is_input_tile, is_input_tile_emb
from tzrec.datasets.utils import Batch
from tzrec.features.feature import BaseFeature
from tzrec.modules.dense_embedding_collection import (
    DenseEmbeddingCollection,
)
from tzrec.modules.sequence import create_seq_encoder
from tzrec.protos import model_pb2
from tzrec.protos.model_pb2 import FeatureGroupConfig, SeqGroupConfig
from tzrec.utils.fx_util import fx_int_item

EMPTY_KJT = KeyedJaggedTensor.empty()


torch.fx.wrap(fx_int_item)


@torch.fx.wrap
def _update_dict_tensor(
    dict1: Dict[str, torch.Tensor], dict2: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
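    """Merge dict2 into dict1; concat values on the last dim for duplicate keys."""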
    for key, value in dict2.items():
        if key in dict1:
            dict1[key] = torch.cat([dict1[key], value], dim=-1)
        else:
            dict1[key] = value
    return dict1


@torch.fx.wrap
def _merge_list_of_tensor_dict(
    list_of_tensor_dict: List[Dict[str, torch.Tensor]],
) -> Dict[str, torch.Tensor]:
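    """Merge a list of tensor dicts into a single dict."""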
    result = {}
    for tensor_dict in list_of_tensor_dict:
        result.update(tensor_dict)
    return result


@torch.fx.wrap
def _merge_list_of_dict_of_jt_dict(
    list_of_dict_of_jt_odict: List[Dict[str, Dict[str, JaggedTensor]]],
) -> Dict[str, Dict[str, JaggedTensor]]:
    result: Dict[str, Dict[str, JaggedTensor]] = {}
    for tensor_d_d in list_of_dict_of_jt_odict:
        result.update(tensor_d_d)
    return result


@torch.fx.wrap
def _merge_list_of_jt_dict(
    list_of_jt_dict: List[Dict[str, JaggedTensor]],
) -> Dict[str, JaggedTensor]:
    result: Dict[str, JaggedTensor] = {}
    for jt_dict in list_of_jt_dict:
        result.update(jt_dict)
    return result


@torch.fx.wrap
def _tile_and_combine_dense_kt(
    user_kt: Optional[KeyedTensor], item_kt: Optional[KeyedTensor], tile_size: int
) -> KeyedTensor:
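    """Tile user-side dense KeyedTensor by tile_size and concat with item-side KT."""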
    kt_keys: List[str] = []
    kt_length_per_key: List[int] = []
    kt_values: List[torch.Tensor] = []
    if user_kt is not None:
        kt_keys.extend(user_kt.keys())
        kt_length_per_key.extend(user_kt.length_per_key())
        kt_values.append(user_kt.values().tile(tile_size, 1))
    if item_kt is not None:
        kt_keys.extend(item_kt.keys())
        kt_length_per_key.extend(item_kt.length_per_key())
        kt_values.append(item_kt.values())
    return KeyedTensor(
        keys=kt_keys,
        length_per_key=kt_length_per_key,
        values=torch.cat(kt_values, dim=1),
    )


@torch.fx.wrap
def _dense_to_jt(t: torch.Tensor) -> JaggedTensor:
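    """Wrap a dense tensor as a JaggedTensor with one value per row."""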
    return JaggedTensor(
        values=t,
        lengths=torch.ones(
            t.shape[0],
            dtype=torch.int64,
            device=t.device,
        ),
    )


class EmbeddingGroup(nn.Module):
    """Applies embedding lookup transformation for feature group.

    Args:
        features (list): list of features.
        feature_groups (list): list of feature group config.
        wide_embedding_dim (int, optional): wide group feature embedding dim.
        device (torch.device): embedding device, default is meta.
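
    Example (a minimal usage sketch; ``features``, ``feature_groups`` and
    ``batch`` are assumed to be built elsewhere from the model config)::

        emb_group = EmbeddingGroup(features, feature_groups)
        # dict of group name (or {group_name}.query/.sequence/.sequence_length
        # for sequence groups) to embedded tensor.
        grouped = emb_group(batch)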
    """

    def __init__(
        self,
        features: List[BaseFeature],
        feature_groups: List[FeatureGroupConfig],
        wide_embedding_dim: Optional[int] = None,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        if device is None:
            device = torch.device("meta")
        self._features = features
        self._feature_groups = feature_groups
        self._name_to_feature = {x.name: x for x in features}
        self._name_to_feature_group = {x.group_name: x for x in feature_groups}

        self.emb_impls = nn.ModuleDict()
        self.seq_emb_impls = nn.ModuleDict()
        self.seq_encoders = nn.ModuleDict()

        self._impl_key_to_feat_groups = defaultdict(list)
        self._impl_key_to_seq_groups = defaultdict(list)
        self._group_name_to_impl_key = dict()
        self._group_name_to_seq_encoder_configs = defaultdict(list)
        self._grouped_features_keys = list()

        seq_group_names = []
        for feature_group in feature_groups:
            group_name = feature_group.group_name
            self._inspect_and_supplement_feature_group(feature_group, seq_group_names)
            # self._add_feature_group_sign_for_sequence_groups(feature_group)
            features_data_group = defaultdict(list)
            for feature_name in feature_group.feature_names:
                feature = self._name_to_feature[feature_name]
                features_data_group[feature.data_group].append(feature_name)
            for sequence_group in feature_group.sequence_groups:
                for feature_name in sequence_group.feature_names:
                    feature = self._name_to_feature[feature_name]
                    features_data_group[feature.data_group].append(feature_name)

            if len(features_data_group) > 1:
                error_info = [",".join(v) for v in features_data_group.values()]
                raise ValueError(
                    f"Feature {error_info} should not belong to same feature group."
                )
            impl_key = list(features_data_group.keys())[0]
            self._group_name_to_impl_key[group_name] = impl_key
            if feature_group.group_type == model_pb2.SEQUENCE:
                self._impl_key_to_seq_groups[impl_key].append(feature_group)
                self._grouped_features_keys.append(group_name + ".query")
                self._grouped_features_keys.append(group_name + ".sequence")
                self._grouped_features_keys.append(group_name + ".sequence_length")
            else:
                self._impl_key_to_feat_groups[impl_key].append(feature_group)
                if len(feature_group.sequence_groups) > 0:
                    self._impl_key_to_seq_groups[impl_key].extend(
                        list(feature_group.sequence_groups)
                    )
                if len(feature_group.sequence_encoders) > 0:
                    self._group_name_to_seq_encoder_configs[group_name] = list(
                        feature_group.sequence_encoders
                    )
                self._grouped_features_keys.append(group_name)

        for k, v in self._impl_key_to_feat_groups.items():
            self.emb_impls[k] = EmbeddingGroupImpl(
                features,
                feature_groups=v,
                wide_embedding_dim=wide_embedding_dim,
                device=device,
            )

        for k, v in self._impl_key_to_seq_groups.items():
            self.seq_emb_impls[k] = SequenceEmbeddingGroupImpl(
                features, feature_groups=v, device=device
            )
        self._group_name_to_seq_encoders = nn.ModuleDict()
        for (
            group_name,
            seq_encoder_configs,
        ) in self._group_name_to_seq_encoder_configs.items():
            impl_key = self._group_name_to_impl_key[group_name]
            seq_emb = self.seq_emb_impls[impl_key]
            group_seq_encoders = nn.ModuleList()
            for seq_encoder_config in seq_encoder_configs:
                seq_encoder = create_seq_encoder(
                    seq_encoder_config, seq_emb.all_group_total_dim()
                )
                group_seq_encoders.append(seq_encoder)
            self._group_name_to_seq_encoders[group_name] = group_seq_encoders

        self._group_feature_dims = OrderedDict()
        for feature_group in feature_groups:
            group_name = feature_group.group_name
            if feature_group.group_type != model_pb2.SEQUENCE:
                feature_dim = OrderedDict()
                impl_key = self._group_name_to_impl_key[group_name]
                feature_emb = self.emb_impls[impl_key]
                feature_dim.update(feature_emb.group_feature_dims(group_name))
                if group_name in self._group_name_to_seq_encoders:
                    seq_encoders = self._group_name_to_seq_encoders[group_name]
                    for i, seq_encoder in enumerate(seq_encoders):
                        feature_dim[f"{group_name}_seq_encoder_{i}"] = (
                            seq_encoder.output_dim()
                        )
                self._group_feature_dims[group_name] = feature_dim

        self._grouped_features_keys.sort()

    def grouped_features_keys(self) -> List[str]:
        """grouped_features_keys."""
        return self._grouped_features_keys

    def _inspect_and_supplement_feature_group(
        self, feature_group: FeatureGroupConfig, seq_group_names: List[str]
    ) -> None:
        """Inspect feature group sequence_groups and sequence_encoders."""
        group_name = feature_group.group_name
        sequence_groups = list(feature_group.sequence_groups)
        sequence_encoders = list(feature_group.sequence_encoders)
        is_deep = feature_group.group_type == model_pb2.DEEP
        if is_deep:
            if len(sequence_groups) == 0 and len(sequence_encoders) == 0:
                return
            elif len(sequence_groups) > 0 and len(sequence_encoders) == 0:
                raise ValueError(
                    f"{group_name} group has sequence_groups but no sequence_encoders."
                )
            elif len(sequence_groups) == 0 and len(sequence_encoders) > 0:
                raise ValueError(
                    f"{group_name} group has sequence_encoders but no sequence_groups."
                )

            if len(sequence_groups) > 1:
                for sequence_group in sequence_groups:
                    if not sequence_group.HasField("group_name"):
                        raise ValueError(
                            f"{group_name} has many sequence_groups, "
                            f"every sequence_group must has group_name"
                        )
            elif len(sequence_groups) == 1 and not sequence_groups[0].HasField(
                "group_name"
            ):
                sequence_groups[0].group_name = group_name

            for sequence_group in sequence_groups:
                if sequence_group.group_name in seq_group_names:
                    raise ValueError(
                        f"has repeat sequences groups_name: {sequence_group.group_name}"
                    )
                else:
                    seq_group_names.append(sequence_group.group_name)

            group_has_encoder = {
                sequence_group.group_name: False for sequence_group in sequence_groups
            }
            for sequence_encoder in sequence_encoders:
                seq_type = sequence_encoder.WhichOneof("seq_module")
                seq_config = getattr(sequence_encoder, seq_type)
                if not seq_config.HasField("input") and len(sequence_groups) == 1:
                    seq_config.input = sequence_groups[0].group_name
                if not seq_config.HasField("input"):
                    raise ValueError(
                        f"{group_name} group has multi sequence_groups, "
                        f"so sequence_encoders must has input"
                    )
                if seq_config.input not in group_has_encoder:
                    raise ValueError(
                        f"{group_name} sequence_encoder input {seq_config.input} "
                        f"not in sequence_groups"
                    )
                else:
                    group_has_encoder[seq_config.input] = True
            for k, v in group_has_encoder.items():
                if not v:
                    raise ValueError(
                        f"{group_name} sequence_groups {k} not has seq_encoder"
                    )
        else:
            if len(sequence_groups) > 0 or len(sequence_encoders) > 0:
                raise ValueError(
                    f"{group_name} group group_type is not DEEP, "
                    f"sequence_groups and sequence_encoders must configured in DEEP"
                )

    def group_names(self) -> List[str]:
        """Feature group names."""
        return list(self._name_to_feature_group.keys())

    def group_dims(self, group_name: str) -> List[int]:
        """Output dimension of each feature in a feature group.

        Args:
            group_name (str): feature group name, when group type is sequence,
                should use {group_name}.query or {group_name}.sequence.

        Return:
            group_dims (list): output dimension of each feature.
        """
        true_name = group_name.split(".")[0] if "." in group_name else group_name
        feature_group = self._name_to_feature_group[true_name]
        impl_key = self._group_name_to_impl_key[true_name]
        if feature_group.group_type == model_pb2.SEQUENCE:
            return self.seq_emb_impls[impl_key].group_dims(group_name)
        else:
            dims = self.emb_impls[impl_key].group_dims(group_name)
            if group_name in self._group_name_to_seq_encoders:
                for seq_encoder in self._group_name_to_seq_encoders[group_name]:
                    dims.append(seq_encoder.output_dim())
            return dims

    def group_total_dim(self, group_name: str) -> int:
        """Total output dimension of a feature group.

        Args:
            group_name (str): feature group name, when group type is sequence,
                should use {group_name}.query or {group_name}.sequence.

        Return:
            total_dim (int): total dimension of feature group.
        """
        true_name = group_name.split(".")[0] if "." in group_name else group_name
        feature_group = self._name_to_feature_group[true_name]
        impl_key = self._group_name_to_impl_key[true_name]
        if feature_group.group_type == model_pb2.SEQUENCE:
            return self.seq_emb_impls[impl_key].group_total_dim(group_name)
        else:
            return sum(self._group_feature_dims[group_name].values())

    def group_feature_dims(self, group_name: str) -> Dict[str, int]:
        """Every feature group each feature dim."""
        true_name = group_name.split(".")[0] if "." in group_name else group_name
        feature_group = self._name_to_feature_group[true_name]
        if feature_group.group_type == model_pb2.SEQUENCE:
            raise ValueError("not support sequence group")
        return self._group_feature_dims[group_name]

    def has_group(self, group_name: str) -> bool:
        """Check the feature group exist or not."""
        true_name = group_name.split(".")[0] if "." in group_name else group_name
        return true_name in self._name_to_feature_group.keys()

    def forward(
        self,
        batch: Batch,
    ) -> Dict[str, torch.Tensor]:
        """Forward the module.

        Args:
            batch (Batch): an instance of Batch with features.

        Returns:
            group_features (dict): dict of feature_group to embedded tensor.
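                For non-sequence groups the key is the group name; for sequence
                groups the keys are ``{group_name}.query``, ``{group_name}.sequence``
                and ``{group_name}.sequence_length``.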
        """
        result_dicts = []

        need_input_tile = is_input_tile()

        if need_input_tile:
            emb_keys = list(self.emb_impls.keys())
            seq_emb_keys = list(self.seq_emb_impls.keys())
            unique_keys = list(set(emb_keys + seq_emb_keys))
            # tile user dense feat & combine item dense feat, when has user dense feat
            for key in unique_keys:
                user_kt = batch.dense_features.get(key + "_user", None)
                if user_kt is not None:
                    item_kt = batch.dense_features.get(key, None)
                    batch.dense_features[key] = _tile_and_combine_dense_kt(
                        user_kt, item_kt, batch.tile_size
                    )

        for key, emb_impl in self.emb_impls.items():
            sparse_feat_kjt = None
            sparse_feat_kjt_user = None
            dense_feat_kt = None

            if emb_impl.has_dense:
                dense_feat_kt = batch.dense_features[key]
            if emb_impl.has_sparse or emb_impl.has_mc_sparse:
                sparse_feat_kjt = batch.sparse_features[key]
            if emb_impl.has_sparse_user or emb_impl.has_mc_sparse_user:
                sparse_feat_kjt_user = batch.sparse_features[key + "_user"]

            result_dicts.append(
                emb_impl(
                    sparse_feat_kjt,
                    dense_feat_kt,
                    sparse_feat_kjt_user,
                    batch.tile_size,
                )
            )

        for key, seq_emb_impl in self.seq_emb_impls.items():
            sparse_feat_kjt = None
            sparse_feat_kjt_user = None
            dense_feat_kt = None
            sequence_mulval_length_kjt = None
            sequence_mulval_length_kjt_user = None

            if seq_emb_impl.has_dense:
                dense_feat_kt = batch.dense_features[key]
            if seq_emb_impl.has_sparse or seq_emb_impl.has_mc_sparse:
                sparse_feat_kjt = batch.sparse_features[key]
            if seq_emb_impl.has_mulval_seq:
                sequence_mulval_length_kjt = batch.sequence_mulval_lengths[key]
            if seq_emb_impl.has_sparse_user or seq_emb_impl.has_mc_sparse_user:
                sparse_feat_kjt_user = batch.sparse_features[key + "_user"]
            if seq_emb_impl.has_mulval_seq_user:
                sequence_mulval_length_kjt_user = batch.sequence_mulval_lengths[
                    key + "_user"
                ]

            result_dicts.append(
                seq_emb_impl(
                    sparse_feat_kjt,
                    dense_feat_kt,
                    batch.sequence_dense_features,
                    sequence_mulval_length_kjt,
                    sparse_feat_kjt_user,
                    sequence_mulval_length_kjt_user,
                    batch.tile_size,
                )
            )

        result = _merge_list_of_tensor_dict(result_dicts)
        seq_feature_dict = {}
        for group_name, seq_encoders in self._group_name_to_seq_encoders.items():
            new_feature = []
            for seq_encoder in seq_encoders:
                new_feature.append(seq_encoder(result))
            seq_feature_dict[group_name] = torch.cat(new_feature, dim=-1)
        return _update_dict_tensor(result, seq_feature_dict)

    def predict(
        self,
        batch: Batch,
    ) -> List[torch.Tensor]:
        """Predict embedding module and return a list of grouped embedding features."""
        grouped_features = self.forward(batch)
        values_list = []
        for key in self._grouped_features_keys:
            values_list.append(grouped_features[key])
        return values_list


def _add_embedding_bag_config(
    emb_bag_configs: Dict[str, EmbeddingBagConfig], emb_bag_config: EmbeddingBagConfig
) -> None:
    """Add embedding bag config to a dict of embedding bag config.

    Args:
        emb_bag_configs(Dict[str, EmbeddingBagConfig]): a dict contains emb_bag_configs
        emb_bag_config(EmbeddingBagConfig): an instance of EmbeddingBagConfig
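
    Example (an illustrative sketch; table name and sizes are made up)::

        configs: Dict[str, EmbeddingBagConfig] = {}
        for feat in ["cat_a", "cat_b"]:
            _add_embedding_bag_config(
                configs,
                EmbeddingBagConfig(
                    name="cat_emb",
                    num_embeddings=100,
                    embedding_dim=8,
                    feature_names=[feat],
                ),
            )
        # configs["cat_emb"].feature_names == ["cat_a", "cat_b"]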
    """
    if emb_bag_config.name in emb_bag_configs:
        existed_emb_bag_config = emb_bag_configs[emb_bag_config.name]
        assert (
            emb_bag_config.num_embeddings == existed_emb_bag_config.num_embeddings
            and emb_bag_config.embedding_dim == existed_emb_bag_config.embedding_dim
            and emb_bag_config.pooling == existed_emb_bag_config.pooling
            and repr(emb_bag_config.init_fn) == repr(existed_emb_bag_config.init_fn)
        ), (
            f"there is a mismatch between {emb_bag_config} and "
            f"{existed_emb_bag_config}, can not share embedding."
        )
        for feature_name in emb_bag_config.feature_names:
            if feature_name not in existed_emb_bag_config.feature_names:
                existed_emb_bag_config.feature_names.append(feature_name)
    else:
        emb_bag_configs[emb_bag_config.name] = emb_bag_config


def _add_embedding_config(
    emb_configs: Dict[str, EmbeddingConfig], emb_config: EmbeddingConfig
) -> None:
    """Add embedding config to a dict of embedding config.

    Args:
        emb_configs(Dict[str, EmbeddingConfig]): a dict contains emb_configs
        emb_config(EmbeddingConfig): an instance of EmbeddingConfig
    """
    if emb_config.name in emb_configs:
        existed_emb_config = emb_configs[emb_config.name]
        assert (
            emb_config.num_embeddings == existed_emb_config.num_embeddings
            and emb_config.embedding_dim == existed_emb_config.embedding_dim
            and repr(emb_config.init_fn) == repr(existed_emb_config.init_fn)
        ), (
            f"there is a mismatch between {emb_config} and "
            f"{existed_emb_config}, can not share embedding."
        )
        for feature_name in emb_config.feature_names:
            if feature_name not in existed_emb_config.feature_names:
                existed_emb_config.feature_names.append(feature_name)
    else:
        emb_configs[emb_config.name] = emb_config


def _add_mc_module(
    mc_modules: Dict[str, ManagedCollisionModule],
    emb_name: str,
    mc_module: ManagedCollisionModule,
) -> None:
    """Add ManagedCollisionModule to a dict of ManagedCollisionModule.

    Args:
        mc_modules(Dict[str, ManagedCollisionModule]): a dict of ManagedCollisionModule.
        emb_name(str): embedding_name.
        mc_module(ManagedCollisionModule): an instance of ManagedCollisionModule.
    """
    if emb_name in mc_modules:
        existed_mc_module = mc_modules[emb_name]
        if isinstance(mc_module, MCHManagedCollisionModule):
            assert isinstance(existed_mc_module, MCHManagedCollisionModule)
            assert mc_module._zch_size == existed_mc_module._zch_size
            assert mc_module._eviction_interval == existed_mc_module._eviction_interval
            assert repr(mc_module._eviction_policy) == repr(
                existed_mc_module._eviction_policy
            )
    mc_modules[emb_name] = mc_module


class EmbeddingGroupImpl(nn.Module):
    """Applies embedding lookup transformation for feature group.

    Args:
        features (list): list of features.
        feature_groups (list): list of feature group config.
        wide_embedding_dim (int, optional): wide group feature embedding dim.
        device (torch.device): embedding device, default is meta.
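
    Example (an illustrative sketch; the impl is normally constructed by
    ``EmbeddingGroup`` with feature groups that share one data group)::

        impl = EmbeddingGroupImpl(features, feature_groups=[deep_group])
        # dict of group name to a [batch_size, group_total_dim] tensor.
        grouped = impl(sparse_kjt, dense_kt, EMPTY_KJT)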
    """

    def __init__(
        self,
        features: List[BaseFeature],
        feature_groups: List[FeatureGroupConfig],
        wide_embedding_dim: Optional[int] = None,
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        if device is None:
            device = torch.device("meta")
        name_to_feature = {x.name: x for x in features}

        need_input_tile_emb = is_input_tile_emb()

        emb_bag_configs = OrderedDict()
        mc_emb_bag_configs = OrderedDict()
        mc_modules = OrderedDict()
        dense_embedding_configs = []
        self.has_sparse = False
        self.has_mc_sparse = False
        self.has_dense = False
        self.has_dense_embedding = False

        # for sparse input-tile-emb
        emb_bag_configs_user = OrderedDict()
        mc_emb_bag_configs_user = OrderedDict()
        mc_modules_user = OrderedDict()
        self.has_sparse_user = False
        self.has_mc_sparse_user = False

        self._group_to_feature_names = OrderedDict()
        self._group_to_shared_feature_names = OrderedDict()
        self._group_total_dim = dict()
        self._group_feature_output_dims = dict()
        self._group_dense_feature_names = dict()
        self._group_dense_embedding_feature_names = dict()

        feat_to_group_to_emb_name = defaultdict(dict)
        for feature_group in feature_groups:
            group_name = feature_group.group_name
            for feature_name in feature_group.feature_names:
                feature = name_to_feature[feature_name]
                if feature.is_sparse:
                    emb_bag_config = feature.emb_bag_config
                    # pyre-ignore [16]
                    emb_name = emb_bag_config.name
                    if feature_group.group_type == model_pb2.WIDE:
                        emb_name = emb_name + "_wide"
                    feat_to_group_to_emb_name[feature_name][group_name] = emb_name

        shared_feature_flag = dict()
        for feature_name, group_to_emb_name in feat_to_group_to_emb_name.items():
            if len(set(group_to_emb_name.values())) > 1:
                shared_feature_flag[feature_name] = True
            else:
                shared_feature_flag[feature_name] = False

        non_emb_dense_feature_to_dim = OrderedDict()
        for feature_group in feature_groups:
            total_dim = 0
            feature_output_dims = OrderedDict()
            group_name = feature_group.group_name
            feature_names = list(feature_group.feature_names)
            shared_feature_names = []
            is_wide = feature_group.group_type == model_pb2.WIDE
            for name in feature_names:
                shared_name = name
                feature = name_to_feature[name]

                if feature.is_sparse:
                    output_dim = feature.output_dim
                    emb_bag_config = feature.emb_bag_config
                    mc_module = feature.mc_module(device)
                    assert emb_bag_config is not None
                    if is_wide:
                        # TODO(hongsheng.jhs): change embedding_dim to 1
                        # when fbgemm support embedding_dim=1
                        emb_bag_config.embedding_dim = output_dim = (
                            wide_embedding_dim or 4
                        )
                    # we may modify ebc name at feat_to_group_to_emb_name, e.g., wide
                    emb_bag_config.name = feat_to_group_to_emb_name[name][group_name]

                    if need_input_tile_emb and feature.is_user_feat:
                        _add_embedding_bag_config(
                            emb_bag_configs=mc_emb_bag_configs_user
                            if mc_module
                            else emb_bag_configs_user,
                            emb_bag_config=emb_bag_config,
                        )
                        if mc_module:
                            _add_mc_module(
                                mc_modules_user, emb_bag_config.name, mc_module
                            )
                            self.has_mc_sparse_user = True
                        else:
                            self.has_sparse_user = True
                    else:
                        _add_embedding_bag_config(
                            emb_bag_configs=mc_emb_bag_configs
                            if mc_module
                            else emb_bag_configs,
                            emb_bag_config=emb_bag_config,
                        )
                        if mc_module:
                            _add_mc_module(mc_modules, emb_bag_config.name, mc_module)
                            self.has_mc_sparse = True
                        else:
                            self.has_sparse = True

                    if shared_feature_flag[name]:
                        shared_name = shared_name + "@" + emb_bag_config.name
                else:
                    output_dim = feature.output_dim

                    if is_wide:
                        raise ValueError(
                            f"dense feature [{name}] should not be configured in "
                            "wide group."
                        )
                    else:
                        self.has_dense = True
                        if feature.dense_emb_config:
                            self.has_dense_embedding = True
                            conf_obj = feature.dense_emb_config
                            dense_embedding_configs.append(conf_obj)
                        else:
                            non_emb_dense_feature_to_dim[name] = output_dim

                total_dim += output_dim
                feature_output_dims[name] = output_dim
                shared_feature_names.append(shared_name)
            self._group_to_feature_names[group_name] = feature_names
            if len(shared_feature_names) > 0:
                self._group_to_shared_feature_names[group_name] = shared_feature_names
            self._group_total_dim[group_name] = total_dim
            self._group_feature_output_dims[group_name] = feature_output_dims

        self.ebc = EmbeddingBagCollection(list(emb_bag_configs.values()), device=device)
        if self.has_mc_sparse:
            self.mc_ebc = ManagedCollisionEmbeddingBagCollection(
                EmbeddingBagCollection(
                    list(mc_emb_bag_configs.values()), device=device
                ),
                ManagedCollisionCollection(
                    mc_modules, list(mc_emb_bag_configs.values())
                ),
            )
        if self.has_dense_embedding:
            self.dense_ec = DenseEmbeddingCollection(
                dense_embedding_configs,
                device=device,
                raw_dense_feature_to_dim=non_emb_dense_feature_to_dim,
            )

        if need_input_tile_emb:
            self.ebc_user = EmbeddingBagCollection(
                list(emb_bag_configs_user.values()), device=device
            )
            if self.has_mc_sparse_user:
                self.mc_ebc_user = ManagedCollisionEmbeddingBagCollection(
                    EmbeddingBagCollection(
                        list(mc_emb_bag_configs_user.values()), device=device
                    ),
                    ManagedCollisionCollection(
                        mc_modules_user, list(mc_emb_bag_configs_user.values())
                    ),
                )

    def group_dims(self, group_name: str) -> List[int]:
        """Output dimension of each feature in a feature group."""
        return list(self._group_feature_output_dims[group_name].values())

    def group_feature_dims(self, group_name: str) -> Dict[str, int]:
        """Output dimension of each feature in a feature group."""
        return self._group_feature_output_dims[group_name]

    def group_total_dim(self, group_name: str) -> int:
        """Total output dimension of a feature group."""
        return self._group_total_dim[group_name]

    def forward(
        self,
        sparse_feature: KeyedJaggedTensor,
        dense_feature: KeyedTensor,
        sparse_feature_user: KeyedJaggedTensor,
        tile_size: int = -1,
    ) -> Dict[str, torch.Tensor]:
        """Forward the module.

        Args:
            sparse_feature (KeyedJaggedTensor): sparse id feature.
            dense_feature (KeyedTensor): dense feature.
            sparse_feature_user (KeyedJaggedTensor): user-side sparse feature
                with batch_size=1, when using INPUT_TILE=3.
            tile_size: size for user-side feature input tile.

        Returns:
            group_features (dict): dict of feature_group to embedded tensor.
        """
        kts: List[KeyedTensor] = []
        if self.has_sparse:
            kts.append(self.ebc(sparse_feature))

        if self.has_mc_sparse:
            kts.append(self.mc_ebc(sparse_feature)[0])

        # do user-side embedding input-tile
        if self.has_sparse_user:
            keyed_tensor_user = self.ebc_user(sparse_feature_user)
            values_tile = keyed_tensor_user.values().tile(tile_size, 1)
            keyed_tensor_user_tile = KeyedTensor(
                keys=keyed_tensor_user.keys(),
                length_per_key=keyed_tensor_user.length_per_key(),
                values=values_tile,
            )
            kts.append(keyed_tensor_user_tile)

        # do user-side mc embedding input-tile
        if self.has_mc_sparse_user:
            keyed_tensor_user = self.mc_ebc_user(sparse_feature_user)[0]
            values_tile = keyed_tensor_user.values().tile(tile_size, 1)
            keyed_tensor_user_tile = KeyedTensor(
                keys=keyed_tensor_user.keys(),
                length_per_key=keyed_tensor_user.length_per_key(),
                values=values_tile,
            )
            kts.append(keyed_tensor_user_tile)

        if self.has_dense:
            if self.has_dense_embedding:
                kts.append(self.dense_ec(dense_feature))
            else:
                kts.append(dense_feature)

        group_tensors = KeyedTensor.regroup_as_dict(
            kts,
            list(self._group_to_shared_feature_names.values()),
            list(self._group_to_shared_feature_names.keys()),
        )

        return group_tensors


class SequenceEmbeddingGroup(nn.Module):
    """Applies embedding lookup transformation for feature group.

    Args:
        features (list): list of features.
        feature_groups (list): list of feature group config.
        wide_embedding_dim (int, optional): wide group feature embedding dim.
        device (torch.device): embedding device, default is meta.
    """

    def __init__(
        self,
        features: List[BaseFeature],
        feature_groups: List[FeatureGroupConfig],
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        if device is None:
            device = torch.device("meta")
        self._features = features
        self._feature_groups = feature_groups
        self._name_to_feature = {x.name: x for x in features}
        self._name_to_feature_group = {x.group_name: x for x in feature_groups}

        self.seq_emb_impls = nn.ModuleDict()

        self._impl_key_to_seq_groups = defaultdict(list)
        self._group_name_to_impl_key = dict()

        for feature_group in feature_groups:
            assert feature_group.group_type == model_pb2.SEQUENCE

            group_name = feature_group.group_name
            features_data_group = defaultdict(list)
            for feature_name in feature_group.feature_names:
                feature = self._name_to_feature[feature_name]
                features_data_group[feature.data_group].append(feature_name)

            if len(features_data_group) > 1:
                error_info = [",".join(v) for v in features_data_group.values()]
                raise ValueError(
                    f"Feature {error_info} should not belong to same feature group."
                )

            impl_key = list(features_data_group.keys())[0]
            self._group_name_to_impl_key[group_name] = impl_key
            self._impl_key_to_seq_groups[impl_key].append(feature_group)

        for k, v in self._impl_key_to_seq_groups.items():
            self.seq_emb_impls[k] = SequenceEmbeddingGroupImpl(
                features, feature_groups=v, device=device
            )

    def group_names(self) -> List[str]:
        """Feature group names."""
        return list(self._name_to_feature_group.keys())

    def group_dims(self, group_name: str) -> List[int]:
        """Output dimension of each feature in a feature group.

        Args:
            group_name (str): feature group name, when group type is sequence,
                should use {group_name}.query or {group_name}.sequence.

        Return:
            group_dims (list): output dimension of each feature.
        """
        true_name = group_name.split(".")[0] if "." in group_name else group_name
        impl_key = self._group_name_to_impl_key[true_name]
        return self.seq_emb_impls[impl_key].group_dims(group_name)

    def group_total_dim(self, group_name: str) -> int:
        """Total output dimension of a feature group.

        Args:
            group_name (str): feature group name, when group type is sequence,
                should use {group_name}.query or {group_name}.sequence.

        Return:
            total_dim (int): total dimension of feature group.
        """
        true_name = group_name.split(".")[0] if "." in group_name else group_name
        impl_key = self._group_name_to_impl_key[true_name]
        return self.seq_emb_impls[impl_key].group_total_dim(group_name)

    def has_group(self, group_name: str) -> bool:
        """Check the feature group exist or not."""
        true_name = group_name.split(".")[0] if "." in group_name else group_name
        return true_name in self._name_to_feature_group.keys()

    def forward(
        self,
        batch: Batch,
    ) -> Dict[str, torch.Tensor]:
        """Forward the module.

        Args:
            batch (Batch): an instance of Batch with features.

        Returns:
            group_features (dict): dict of feature_group to embedded tensor.
        """
        result_dicts = []

        need_input_tile = is_input_tile()

        if need_input_tile:
            unique_keys = list(self.seq_emb_impls.keys())
            # tile user dense feat & combine item dense feat, when has user dense feat
            for key in unique_keys:
                user_kt = batch.dense_features.get(key + "_user", None)
                if user_kt is not None:
                    item_kt = batch.dense_features.get(key, None)
                    batch.dense_features[key] = _tile_and_combine_dense_kt(
                        user_kt, item_kt, batch.tile_size
                    )

        for key, seq_emb_impl in self.seq_emb_impls.items():
            sparse_feat_kjt = None
            sparse_feat_kjt_user = None
            dense_feat_kt = None
            sequence_mulval_length_kjt = None
            sequence_mulval_length_kjt_user = None

            if seq_emb_impl.has_dense:
                dense_feat_kt = batch.dense_features[key]
            if seq_emb_impl.has_sparse or seq_emb_impl.has_mc_sparse:
                sparse_feat_kjt = batch.sparse_features[key]
            if seq_emb_impl.has_mulval_seq:
                sequence_mulval_length_kjt = batch.sequence_mulval_lengths[key]
            if seq_emb_impl.has_sparse_user or seq_emb_impl.has_mc_sparse_user:
                sparse_feat_kjt_user = batch.sparse_features[key + "_user"]
            if seq_emb_impl.has_mulval_seq_user:
                sequence_mulval_length_kjt_user = batch.sequence_mulval_lengths[
                    key + "_user"
                ]

            result_dicts.append(
                seq_emb_impl(
                    sparse_feat_kjt,
                    dense_feat_kt,
                    batch.sequence_dense_features,
                    sequence_mulval_length_kjt,
                    sparse_feat_kjt_user,
                    sequence_mulval_length_kjt_user,
                    batch.tile_size,
                )
            )

        return _merge_list_of_tensor_dict(result_dicts)

    def jagged_forward(
        self,
        batch: Batch,
    ) -> Dict[str, OrderedDict[str, JaggedTensor]]:
        """Forward the module.

        Args:
            batch (Batch): an instance of Batch with features.

        Returns:
            group_features (dict): dict of feature_group to dict of embedded tensor.
        """
        need_input_tile = is_input_tile()
        assert not need_input_tile, "jagged forward not support INPUT_TILE now."

        result_dicts = []
        for key, seq_emb_impl in self.seq_emb_impls.items():
            sparse_feat_kjt = None
            dense_feat_kt = None
            sequence_mulval_length_kjt = None

            if seq_emb_impl.has_dense:
                dense_feat_kt = batch.dense_features[key]
            if seq_emb_impl.has_sparse or seq_emb_impl.has_mc_sparse:
                sparse_feat_kjt = batch.sparse_features[key]
            if seq_emb_impl.has_mulval_seq:
                sequence_mulval_length_kjt = batch.sequence_mulval_lengths[key]

            result_dicts.append(
                seq_emb_impl.jagged_forward(
                    sparse_feat_kjt,
                    dense_feat_kt,
                    batch.sequence_dense_features,
                    sequence_mulval_length_kjt,
                )
            )

        result = _merge_list_of_dict_of_jt_dict(result_dicts)
        return result


class _SequenceEmbeddingInfo(NamedTuple):
    """One Embedding info in SequenceGroup."""

    name: str
    raw_name: str
    is_sparse: bool
    pooling: str
    value_dim: int
    is_user: bool
    is_sequence: bool


class SequenceEmbeddingGroupImpl(nn.Module):
    """Applies embedding lookup transformation for sequence feature group.

    Args:
        features (list): list of features.
        feature_groups (list): list of feature group config or seq group config.
        device (torch.device): embedding device, default is meta.
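
    Example (an illustrative sketch; the impl is normally constructed by
    ``EmbeddingGroup`` or ``SequenceEmbeddingGroup``)::

        impl = SequenceEmbeddingGroupImpl(features, feature_groups=[seq_group])
        # returned keys: {group_name}.query [B, query_dim],
        #   {group_name}.sequence [B, max_seq_len, sequence_dim],
        #   {group_name}.sequence_length [B]
        grouped = impl(
            sparse_kjt, dense_kt, seq_dense_jt_dict, mulval_len_kjt,
            EMPTY_KJT, EMPTY_KJT,
        )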
    """

    def __init__(
        self,
        features: List[BaseFeature],
        feature_groups: List[Union[FeatureGroupConfig, SeqGroupConfig]],
        device: Optional[torch.device] = None,
    ) -> None:
        super().__init__()
        if device is None:
            device = torch.device("meta")
        name_to_feature = {x.name: x for x in features}

        need_input_tile = is_input_tile()
        need_input_tile_emb = is_input_tile_emb()

        dim_to_emb_configs = defaultdict(OrderedDict)
        dim_to_mc_emb_configs = defaultdict(OrderedDict)
        dim_to_mc_modules = defaultdict(OrderedDict)
        self.has_sparse = False
        self.has_mc_sparse = False
        self.has_dense = False
        self.has_sequence_dense = False
        self.has_mulval_seq = False

        # for sparse input-tile-emb
        dim_to_emb_configs_user = defaultdict(OrderedDict)
        dim_to_mc_emb_configs_user = defaultdict(OrderedDict)
        dim_to_mc_modules_user = defaultdict(OrderedDict)
        self.has_sparse_user = False
        self.has_mc_sparse_user = False
        self.has_mulval_seq_user = False

        self._group_to_shared_query = OrderedDict()
        self._group_to_shared_sequence = OrderedDict()
        self._group_to_shared_feature = OrderedDict()
        self._group_total_dim = dict()
        self._group_output_dims = dict()

        feat_to_group_to_emb_name = defaultdict(dict)
        for feature_group in feature_groups:
            group_name = feature_group.group_name
            for feature_name in feature_group.feature_names:
                feature = name_to_feature[feature_name]
                if feature.is_sparse:
                    emb_config = feature.emb_config
                    # pyre-ignore [16]
                    emb_name = emb_config.name
                    feat_to_group_to_emb_name[feature_name][group_name] = emb_name

        shared_feature_flag = dict()
        for feature_name, group_to_emb_name in feat_to_group_to_emb_name.items():
            if len(set(group_to_emb_name.values())) > 1:
                shared_feature_flag[feature_name] = True
            else:
                shared_feature_flag[feature_name] = False

        for feature_group in feature_groups:
            query_dim = 0
            sequence_dim = 0
            query_dims = []
            sequence_dims = []
            output_dims = []
            group_name = feature_group.group_name
            feature_names = list(feature_group.feature_names)
            shared_query = []
            shared_sequence = []
            shared_feature = []

            for name in feature_names:
                shared_name = name
                feature = name_to_feature[name]
                if feature.is_sparse:
                    output_dim = feature.output_dim
                    emb_config = feature.emb_config
                    mc_module = feature.mc_module(device)
                    assert emb_config is not None
                    # the ec name may be modified in feat_to_group_to_emb_name
                    emb_config.name = feat_to_group_to_emb_name[name][group_name]
                    embedding_dim = emb_config.embedding_dim

                    if need_input_tile_emb and feature.is_user_feat:
                        emb_configs = (
                            dim_to_mc_emb_configs_user[embedding_dim]
                            if mc_module
                            else dim_to_emb_configs_user[embedding_dim]
                        )
                        _add_embedding_config(
                            emb_configs=emb_configs,
                            emb_config=emb_config,
                        )
                        if mc_module:
                            _add_mc_module(
                                dim_to_mc_modules_user[embedding_dim],
                                emb_config.name,
                                mc_module,
                            )
                            self.has_mc_sparse_user = True
                        else:
                            self.has_sparse_user = True
                        if feature.is_sequence and feature.value_dim != 1:
                            self.has_mulval_seq_user = True
                    else:
                        emb_configs = (
                            dim_to_mc_emb_configs[embedding_dim]
                            if mc_module
                            else dim_to_emb_configs[embedding_dim]
                        )
                        _add_embedding_config(
                            emb_configs=emb_configs,
                            emb_config=emb_config,
                        )
                        if mc_module:
                            _add_mc_module(
                                dim_to_mc_modules[embedding_dim],
                                emb_config.name,
                                mc_module,
                            )
                            self.has_mc_sparse = True
                        else:
                            self.has_sparse = True
                        if feature.is_sequence and feature.value_dim != 1:
                            self.has_mulval_seq = True

                    if shared_feature_flag[name]:
                        shared_name = shared_name + "@" + emb_config.name
                else:
                    output_dim = feature.output_dim
                    if feature.is_sequence:
                        self.has_sequence_dense = True
                    else:
                        self.has_dense = True

                is_user_feat = feature.is_user_feat if need_input_tile else False
                shared_info = _SequenceEmbeddingInfo(
                    name=shared_name,
                    raw_name=name,
                    is_sparse=feature.is_sparse,
                    pooling=feature.pooling_type.value.lower(),
                    value_dim=feature.value_dim,
                    is_user=is_user_feat,
                    is_sequence=feature.is_sequence,
                )
                if feature.is_sequence:
                    shared_sequence.append(shared_info)
                    sequence_dim += output_dim
                    sequence_dims.append(output_dim)
                else:
                    shared_query.append(shared_info)
                    query_dim += output_dim
                    query_dims.append(output_dim)
                shared_feature.append(shared_info)
                output_dims.append(output_dim)

            self._group_to_shared_query[group_name] = shared_query
            self._group_to_shared_sequence[group_name] = shared_sequence
            self._group_to_shared_feature[group_name] = shared_feature
            self._group_total_dim[f"{group_name}.query"] = query_dim
            self._group_total_dim[f"{group_name}.sequence"] = sequence_dim
            self._group_output_dims[f"{group_name}.query"] = query_dims
            self._group_output_dims[f"{group_name}.sequence"] = sequence_dims
            self._group_output_dims[group_name] = output_dims

        self.ec_list = nn.ModuleList()
        for _, emb_configs in dim_to_emb_configs.items():
            self.ec_list.append(
                EmbeddingCollection(list(emb_configs.values()), device=device)
            )

        self.mc_ec_list = nn.ModuleList()
        for k, emb_configs in dim_to_mc_emb_configs.items():
            self.mc_ec_list.append(
                ManagedCollisionEmbeddingCollection(
                    EmbeddingCollection(list(emb_configs.values()), device=device),
                    ManagedCollisionCollection(
                        dim_to_mc_modules[k], list(emb_configs.values())
                    ),
                )
            )
        if need_input_tile_emb:
            self.ec_list_user = nn.ModuleList()
            for _, emb_configs in dim_to_emb_configs_user.items():
                self.ec_list_user.append(
                    EmbeddingCollection(list(emb_configs.values()), device=device)
                )
            self.mc_ec_list_user = nn.ModuleList()
            for k, emb_configs in dim_to_mc_emb_configs_user.items():
                self.mc_ec_list_user.append(
                    ManagedCollisionEmbeddingCollection(
                        EmbeddingCollection(list(emb_configs.values()), device=device),
                        ManagedCollisionCollection(
                            dim_to_mc_modules_user[k], list(emb_configs.values())
                        ),
                    )
                )

    def group_dims(self, group_name: str) -> List[int]:
        """Output dimension of each feature in a feature group."""
        return self._group_output_dims[group_name]

    def group_total_dim(self, group_name: str) -> int:
        """Total output dimension of a feature group."""
        if "." in group_name:
            return self._group_total_dim[group_name]
        else:
            return (
                self._group_total_dim[f"{group_name}.query"]
                + self._group_total_dim[f"{group_name}.sequence"]
            )

    def all_group_total_dim(self) -> Dict[str, int]:
        """Total output dimension of all feature group."""
        return self._group_total_dim

    def has_group(self, group_name: str) -> bool:
        """Check the feature group exist or not."""
        true_name = group_name.split(".")[0] if "." in group_name else group_name
        return true_name in self._group_output_dims.keys()

    def _forward_impl(
        self,
        sparse_feature: KeyedJaggedTensor,
        dense_feature: KeyedTensor,
        sequence_dense_features: Dict[str, JaggedTensor],
        sequence_mulval_lengths: KeyedJaggedTensor,
        sparse_feature_user: KeyedJaggedTensor,
        sequence_mulval_lengths_user: KeyedJaggedTensor,
    ) -> Tuple[Dict[str, JaggedTensor], Dict[str, torch.Tensor]]:
        sparse_jt_dict_list: List[Dict[str, JaggedTensor]] = []
        seq_mulval_length_jt_dict_list: List[Dict[str, JaggedTensor]] = []
        dense_t_dict: Dict[str, torch.Tensor] = {}

        if self.has_sparse:
            for ec in self.ec_list:
                sparse_jt_dict_list.append(ec(sparse_feature))

        if self.has_mc_sparse:
            for ec in self.mc_ec_list:
                sparse_jt_dict_list.append(ec(sparse_feature)[0])

        if self.has_mulval_seq:
            seq_mulval_length_jt_dict_list.append(sequence_mulval_lengths.to_dict())

        if self.has_sparse_user:
            for ec in self.ec_list_user:
                sparse_jt_dict_list.append(ec(sparse_feature_user))

        if self.has_mc_sparse_user:
            for ec in self.mc_ec_list_user:
                sparse_jt_dict_list.append(ec(sparse_feature_user)[0])

        if self.has_mulval_seq_user:
            seq_mulval_length_jt_dict_list.append(
                sequence_mulval_lengths_user.to_dict()
            )

        sparse_jt_dict = _merge_list_of_jt_dict(sparse_jt_dict_list)
        seq_mulval_length_jt_dict = _merge_list_of_jt_dict(
            seq_mulval_length_jt_dict_list
        )

        if self.has_dense:
            dense_t_dict = dense_feature.to_dict()

        seq_jt_dict: Dict[str, JaggedTensor] = {}
        for _, v in self._group_to_shared_sequence.items():
            for info in v:
                if info.name in seq_jt_dict:
                    continue
                jt = (
                    sparse_jt_dict[info.name]
                    if info.is_sparse
                    else sequence_dense_features[info.name]
                )
                if info.is_sparse and info.value_dim != 1:
                    length_jt = seq_mulval_length_jt_dict[info.raw_name]
                    # length_jt.values is sequence key_lengths
                    # length_jt.lengths is sequence seq_lengths
                    jt_values = torch.segment_reduce(
                        jt.values(), info.pooling, lengths=length_jt.values()
                    )
                    if info.pooling == "mean":
                        jt_values = torch.nan_to_num(jt_values, nan=0.0)
                    jt = JaggedTensor(
                        values=jt_values,
                        lengths=length_jt.lengths(),
                    )
                seq_jt_dict[info.name] = jt

        if len(seq_jt_dict) > 0:
            jt_dict = _merge_list_of_jt_dict([sparse_jt_dict, seq_jt_dict])
        else:
            jt_dict = sparse_jt_dict

        return jt_dict, dense_t_dict

    def forward(
        self,
        sparse_feature: KeyedJaggedTensor,
        dense_feature: KeyedTensor,
        sequence_dense_features: Dict[str, JaggedTensor],
        sequence_mulval_lengths: KeyedJaggedTensor,
        sparse_feature_user: KeyedJaggedTensor,
        sequence_mulval_lengths_user: KeyedJaggedTensor,
        tile_size: int = -1,
    ) -> Dict[str, torch.Tensor]:
        """Forward the module.

        Args:
            sparse_feature (KeyedJaggedTensor): sparse id feature.
            dense_feature (KeyedTensor): dense feature.
            sequence_dense_features (Dict[str, JaggedTensor]): dense sequence feature.
            sequence_mulval_lengths (KeyedJaggedTensor): key_lengths and seq_lengths for
                multi-value sparse sequence features, kjt.values is key_lengths,
                kjt.lengths is seq_lengths.
            sparse_feature_user (KeyedJaggedTensor): user-side sparse feature
                with batch_size=1, when using INPUT_TILE=3.
            sequence_mulval_lengths_user (KeyedJaggedTensor): key_lengths and
                seq_lengths of user-side multi-value sparse sequence features,
                with batch_size=1, when using INPUT_TILE=3.
            tile_size: size for user-side feature input tile.

        Returns:
            group_features (dict): dict of feature_group to embedded tensor.
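
        Example (illustrative values for the multi-value lengths convention)::

            # One sample whose multi-value sequence feature has 2 steps with
            # 3 and 1 ids respectively:
            #   sequence_mulval_lengths.values()  -> tensor([3, 1])  # key_lengths
            #   sequence_mulval_lengths.lengths() -> tensor([2])     # seq_lengths
            # Id embeddings are segment-pooled per step with the feature's
            # pooling type before being padded to [batch, max_seq_len, emb_dim].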
        """
        need_input_tile = is_input_tile()
        need_input_tile_emb = is_input_tile_emb()

        jt_dict, dense_t_dict = self._forward_impl(
            sparse_feature,
            dense_feature,
            sequence_dense_features,
            sequence_mulval_lengths,
            sparse_feature_user,
            sequence_mulval_lengths_user,
        )

        results = {}
        for group_name, v in self._group_to_shared_query.items():
            query_t_list = []
            for info in v:
                if info.is_sparse:
                    query_jt = jt_dict[info.name]
                    if info.value_dim == 1:
                        # for single-value id feature
                        query_t = jt_dict[info.name].to_padded_dense(1).squeeze(1)
                    else:
                        # for multi-value id feature
                        query_t = torch.segment_reduce(
                            query_jt.values(), info.pooling, lengths=query_jt.lengths()
                        )
                        if info.pooling == "mean":
                            query_t = torch.nan_to_num(query_t, nan=0.0)
                    if info.is_user and need_input_tile_emb:
                        query_t = query_t.tile(tile_size, 1)
                else:
                    query_t = dense_t_dict[info.name]
                query_t_list.append(query_t)
            if len(query_t_list) > 0:
                results[f"{group_name}.query"] = torch.cat(query_t_list, dim=1)

        for group_name, v in self._group_to_shared_sequence.items():
            seq_t_list = []

            group_sequence_length = 1
            for i, info in enumerate(v):
                # When is_user is True:
                #   sparse sequence features need tile(tile_size, 1) only when
                #   input_tile_emb is enabled; dense sequence features always
                #   need to be tiled.
                need_tile = False
                if info.is_user:
                    if info.is_sparse:
                        need_tile = need_input_tile_emb
                    else:
                        need_tile = need_input_tile
                jt = jt_dict[info.name]
                if i == 0:
                    sequence_length = jt.lengths()
                    group_sequence_length = fx_int_item(torch.max(sequence_length))

                    if need_tile:
                        results[f"{group_name}.sequence_length"] = sequence_length.tile(
                            tile_size
                        )
                    else:
                        results[f"{group_name}.sequence_length"] = sequence_length

                jt = jt.to_padded_dense(group_sequence_length)

                if need_tile:
                    jt = jt.tile(tile_size, 1, 1)
                seq_t_list.append(jt)

            if seq_t_list:
                results[f"{group_name}.sequence"] = torch.cat(seq_t_list, dim=2)

        return results

    def jagged_forward(
        self,
        sparse_feature: KeyedJaggedTensor,
        dense_feature: KeyedTensor,
        sequence_dense_features: Dict[str, JaggedTensor],
        sequence_mulval_lengths: KeyedJaggedTensor,
    ) -> Dict[str, OrderedDict[str, JaggedTensor]]:
        """Forward the module.

        Args:
            sparse_feature (KeyedJaggedTensor): sparse id feature.
            dense_feature (KeyedTensor): dense feature.
            sequence_dense_features (Dict[str, JaggedTensor]): dense sequence feature.
            sequence_mulval_lengths (KeyedJaggedTensor): key_lengths and seq_lengths for
                multi-value sparse sequence features, kjt.values is key_lengths,
                kjt.lengths is seq_lengths.

        Returns:
            group_features (dict): dict of feature_group to dict of embedded tensor.
        """
        need_input_tile = is_input_tile()
        assert not need_input_tile, "jagged forward not support INPUT_TILE now."

        jt_dict, dense_t_dict = self._forward_impl(
            sparse_feature,
            dense_feature,
            sequence_dense_features,
            sequence_mulval_lengths,
            EMPTY_KJT,
            EMPTY_KJT,
        )

        results = {}
        for group_name, v in self._group_to_shared_feature.items():
            group_result = OrderedDict()
            for info in v:
                if info.is_sparse or info.is_sequence:
                    jt = jt_dict[info.name]
                else:
                    jt = _dense_to_jt(dense_t_dict[info.name])
                group_result[info.name] = jt
            results[group_name] = group_result

        return results
