# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import os
import shutil
from collections import OrderedDict
from copy import copy
from functools import partial  # NOQA
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import pyarrow as pa
import pyfg
import torch
from torch import nn  # NOQA
from torchrec.modules.embedding_configs import (
    EmbeddingBagConfig,
    EmbeddingConfig,
    PoolingType,
)
from torchrec.modules.mc_modules import (
    DistanceLFU_EvictionPolicy,
    LFU_EvictionPolicy,
    LRU_EvictionPolicy,
    ManagedCollisionModule,
    MCHManagedCollisionModule,
    average_threshold_filter,  # NOQA
    dynamic_threshold_filter,  # NOQA
    probabilistic_threshold_filter,  # NOQA
)

from tzrec.datasets.utils import (
    BASE_DATA_GROUP,
    C_NEG_SAMPLE_MASK,
    C_SAMPLE_MASK,
    NEG_DATA_GROUP,
    DenseData,
    ParsedData,
    SparseData,
)
from tzrec.modules.dense_embedding_collection import (
    AutoDisEmbeddingConfig,
    DenseEmbeddingConfig,
    MLPDenseEmbeddingConfig,
)
from tzrec.protos.data_pb2 import FgMode
from tzrec.protos.feature_pb2 import FeatureConfig, SequenceFeature
from tzrec.utils import config_util
from tzrec.utils.load_class import get_register_class_meta
from tzrec.utils.logging_util import logger

_FEATURE_CLASS_MAP = {}
_meta_cls = get_register_class_meta(_FEATURE_CLASS_MAP)

MAX_HASH_BUCKET_SIZE = 2**63 - 1


def _parse_fg_encoded_sparse_feature_impl(
    name: str,
    feat: pa.Array,
    multival_sep: str = chr(3),
    default_value: Optional[List[int]] = None,
    is_weighted: bool = False,
) -> SparseData:
    """Parse fg encoded sparse feature.

    Args:
        name (str): feature name.
        feat (pa.Array): input feature data.
        multival_sep (str): string separator for multi-val data.
        default_value (list): default value.
        is_weighted (bool): input feature is weighted or not.

    Returns:
        an instance of SparseData.
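
    Example:
        A minimal sketch (assumes ``SparseData`` exposes ``values`` and
        ``lengths`` arrays as below)::

            feat = pa.array(["1" + chr(3) + "2", "", "3"])
            data = _parse_fg_encoded_sparse_feature_impl(
                "cat_a", feat, default_value=[0]
            )
            # data.values -> [1, 2, 0, 3]; data.lengths -> [2, 1, 1]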
    """
    weight_values = None
    if (
        pa.types.is_string(feat.type)
        or pa.types.is_list(feat.type)
        or pa.types.is_map(feat.type)
    ):
        weight = None
        if pa.types.is_string(feat.type) or pa.types.is_list(feat.type):
            if pa.types.is_string(feat.type):
                # dtype = string
                is_empty = pa.compute.equal(feat, pa.scalar(""))
                nulls = pa.nulls(len(feat))
                feat = pa.compute.if_else(is_empty, nulls, feat)
                feat = pa.compute.split_pattern(feat, multival_sep)
            elif pa.types.is_list(feat.type):
                # dtype = list<int> or others can cast to list<int>
                if default_value is not None:
                    is_empty = pa.compute.equal(pa.compute.list_value_length(feat), 0)
                    nulls = pa.nulls(len(feat))
                    feat = pa.compute.if_else(is_empty, nulls, feat)
            if is_weighted:
                assert pa.types.is_string(feat.values.type)
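                # values look like "id:weight"; after splitting on ":",
                # even slots are ids and odd slots are weights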
                fw = pa.compute.split_pattern(feat.values, ":")
                weight = pa.ListArray.from_arrays(
                    feat.offsets, fw.values[1::2], mask=feat.is_null()
                )
                feat = pa.ListArray.from_arrays(
                    feat.offsets, fw.values[::2], mask=feat.is_null()
                )
        else:
            # dtype = map<int,float> or others can cast to map<int,float>
            weight = pa.ListArray.from_arrays(
                feat.offsets, feat.items, mask=feat.is_null()
            )
            feat = pa.ListArray.from_arrays(
                feat.offsets, feat.keys, mask=feat.is_null()
            )

        feat = feat.cast(pa.list_(pa.int64()), safe=False)
        if weight is not None:
            weight = weight.cast(pa.list_(pa.float32()), safe=False)

        if default_value is not None:
            feat = feat.fill_null(default_value)
            if weight is not None:
                weight = weight.fill_null([1.0])

        feat_values = feat.values.to_numpy()
        feat_offsets = feat.offsets.to_numpy()
        feat_lengths = feat_offsets[1:] - feat_offsets[:-1]
        if weight is not None:
            weight_values = weight.values.to_numpy()
    elif pa.types.is_integer(feat.type):
        assert not is_weighted
        # dtype = int
        if default_value is not None:
            feat = feat.cast(pa.int64()).fill_null(default_value[0])
            feat_values = feat.to_numpy()
            feat_lengths = np.ones_like(feat_values, np.int32)
        else:
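            # no default: drop null values; lengths are 1 for valid rows
            # and 0 for null rows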
            feat_values = feat.drop_null().cast(pa.int64()).to_numpy()
            feat_lengths = 1 - feat.is_null().cast(pa.int32()).to_numpy()
    else:
        raise ValueError(
            f"{name} only supports str|int|list<int>|map<int,double> dtype input, "
            f"but got {feat.type}."
        )
    return SparseData(name, feat_values, feat_lengths, weights=weight_values)


def _parse_fg_encoded_dense_feature_impl(
    name: str,
    feat: pa.Array,
    multival_sep: str = chr(3),
    default_value: Optional[List[float]] = None,
) -> DenseData:
    """Parse fg encoded dense feature.

    Args:
        name (str): feature name.
        feat (pa.Array): input feature data.
        multival_sep (str): string separator for multi-val data.
        default_value (list): default value.

    Returns:
        an instance of DenseData.
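
    Example:
        A minimal sketch (assumes ``DenseData`` exposes a 2-D ``values``
        array as below)::

            feat = pa.array(["0.1" + chr(3) + "0.2", "0.3" + chr(3) + "0.4"])
            data = _parse_fg_encoded_dense_feature_impl("price", feat)
            # data.values -> [[0.1, 0.2], [0.3, 0.4]] as float32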
    """
    if pa.types.is_string(feat.type):
        # dtype = string
        if default_value is not None:
            is_empty = pa.compute.equal(feat, pa.scalar(""))
            feat = pa.compute.if_else(is_empty, pa.nulls(len(feat)), feat)
            feat = feat.fill_null(multival_sep.join(map(str, default_value)))
        list_feat = pa.compute.split_pattern(feat, multival_sep)
        list_feat = list_feat.cast(pa.list_(pa.float32()), safe=False)
        feat_values = np.stack(list_feat.to_numpy(zero_copy_only=False))
    elif pa.types.is_list(feat.type):
        # dtype = list<float> or others can cast to list<float>
        feat = feat.cast(pa.list_(pa.float32()), safe=False)
        if default_value is not None:
            is_empty = pa.compute.equal(pa.compute.list_value_length(feat), 0)
            feat = pa.compute.if_else(is_empty, pa.nulls(len(feat)), feat)
            feat = feat.fill_null(default_value)
        feat_values = np.stack(feat.to_numpy(zero_copy_only=False))
    elif pa.types.is_integer(feat.type) or pa.types.is_floating(feat.type):
        # dtype = int or float
        feat = feat.cast(pa.float32(), safe=False)
        if default_value is not None:
            feat = feat.fill_null(default_value[0])
        feat_values = feat.to_numpy()[:, np.newaxis]
    else:
        raise ValueError(
            f"{name} only supports str|int|float|list<float> dtype input,"
            f" but got {feat.type}."
        )
    return DenseData(name, feat_values)


class InvalidFgInputError(Exception):
    """Invalid Feature side inputs exception."""

    pass


class BaseFeature(object, metaclass=_meta_cls):
    """Base feature class.

    Args:
        feature_config (FeatureConfig): an instance of feature config.
        fg_mode (FgMode): input data fg mode.
        fg_encoded_multival_sep (str, optional): multival_sep when fg_mode=FG_NONE.
    """

    def __init__(
        self,
        feature_config: FeatureConfig,
        fg_mode: FgMode = FgMode.FG_NONE,
        fg_encoded_multival_sep: Optional[str] = None,
    ) -> None:
        fc_type = feature_config.WhichOneof("feature")
        self._feature_config = feature_config
        self.config = getattr(self._feature_config, fc_type)

        self.fg_mode = fg_mode

        self._fg_op = None
        self._is_neg = False
        self._is_sparse = None
        self._is_weighted = False
        self._is_user_feat = None
        self._data_group = BASE_DATA_GROUP
        self._inputs = None
        self._side_inputs = None
        self._vocab_list = None
        self._vocab_dict = None

        self._fg_encoded_kwargs = {}
        self._fg_encoded_multival_sep = fg_encoded_multival_sep or chr(3)
        if self.fg_mode == FgMode.FG_NONE:
            if self.config.HasField("fg_encoded_default_value"):
                self._fg_encoded_kwargs["default_value"] = (
                    self.fg_encoded_default_value()
                )
            elif self.config.use_mask:
                try:
                    self._fg_encoded_kwargs["default_value"] = (
                        self.fg_encoded_default_value()
                    )
                except Exception:
                    raise RuntimeError(
                        "when use_mask is enabled, you should set "
                        f"fg_encoded_default_value for {self.name}"
                    ) from None
            self._fg_encoded_kwargs["multival_sep"] = self._fg_encoded_multival_sep

        if self.fg_mode == FgMode.FG_NORMAL:
            self.init_fg()

    @property
    def name(self) -> str:
        """Feature name."""
        raise NotImplementedError

    @property
    def is_neg(self) -> bool:
        """Feature is negative sampled or not."""
        return self._is_neg

    @is_neg.setter
    def is_neg(self, value: bool) -> None:
        """Feature is negative sampled or not."""
        self._is_neg = value
        self._data_group = NEG_DATA_GROUP

    @property
    def data_group(self) -> str:
        """Data group for the feature."""
        return self._data_group

    @data_group.setter
    def data_group(self, data_group: str) -> None:
        """Data group for the feature."""
        self._data_group = data_group

    @property
    def feature_config(self) -> FeatureConfig:
        """Feature config for the feature."""
        return self._feature_config

    @feature_config.setter
    def feature_config(self, feature_config: FeatureConfig) -> None:
        """Feature config for the feature."""
        fc_type = feature_config.WhichOneof("feature")
        self._feature_config = feature_config
        self.config = getattr(self._feature_config, fc_type)

    @property
    def is_user_feat(self) -> bool:
        """Feature is user side or not."""
        if self._is_user_feat is None:
            # legacy path without dag: is_user_feat may not be set
            if self.is_grouped_sequence:
                return True
            for side, _ in self.side_inputs:
                if side != "user":
                    return False
            return True
        else:
            return self._is_user_feat

    @is_user_feat.setter
    def is_user_feat(self, value: bool) -> None:
        """Feature is user side or not."""
        self._is_user_feat = value

    @property
    def value_dim(self) -> int:
        """Fg value dimension of the feature."""
        raise NotImplementedError

    @property
    def output_dim(self) -> int:
        """Output dimension of the feature after embedding."""
        raise NotImplementedError

    @property
    def is_sparse(self) -> bool:
        """Feature is sparse or dense."""
        if self._is_sparse is None:
            self._is_sparse = False
        return self._is_sparse

    @property
    def is_sequence(self) -> bool:
        """Feature is sequence or not."""
        return False

    @property
    def is_grouped_sequence(self) -> bool:
        """Feature is grouped sequence or not."""
        return False

    @property
    def is_weighted(self) -> bool:
        """Feature is weighted id feature or not."""
        return self._is_weighted

    @property
    def has_embedding(self) -> bool:
        """Feature has embedding or not."""
        if self.is_sparse:
            return True
        else:
            return self._dense_emb_type is not None

    @property
    def pooling_type(self) -> PoolingType:
        """Get embedding pooling type."""
        pooling_type = self.config.pooling.upper()
        assert pooling_type in {"SUM", "MEAN"}, (
            "available pooling types are SUM | MEAN"
        )
        return getattr(PoolingType, pooling_type)

    @property
    def num_embeddings(self) -> int:
        """Get embedding row count."""
        raise NotImplementedError

    @property
    def _embedding_dim(self) -> int:
        if self.has_embedding:
            assert self.config.embedding_dim > 0, (
                f"embedding_dim of {self.__class__.__name__}[{self.name}] "
                "should be greater than 0."
            )
        return self.config.embedding_dim

    @property
    def _dense_emb_type(self) -> Optional[str]:
        return None

    @property
    def emb_bag_config(self) -> Optional[EmbeddingBagConfig]:
        """Get EmbeddingBagConfig of the feature."""
        if self.is_sparse:
            embedding_name = self.config.embedding_name or f"{self.name}_emb"
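            # features configured with the same embedding_name are intended
            # to share one embedding table downstream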
            init_fn = None
            if self.config.HasField("init_fn"):
                init_fn = eval(f"partial({self.config.init_fn})")
            return EmbeddingBagConfig(
                num_embeddings=self.num_embeddings,
                embedding_dim=self._embedding_dim,
                name=embedding_name,
                feature_names=[self.name],
                pooling=self.pooling_type,
                init_fn=init_fn,
            )
        else:
            return None

    @property
    def emb_config(self) -> Optional[EmbeddingConfig]:
        """Get EmbeddingConfig of the feature."""
        if self.is_sparse:
            embedding_name = self.config.embedding_name or f"{self.name}_emb"
            init_fn = None
            if self.config.HasField("init_fn"):
                init_fn = eval(f"partial({self.config.init_fn})")
            return EmbeddingConfig(
                num_embeddings=self.num_embeddings,
                embedding_dim=self._embedding_dim,
                name=embedding_name,
                feature_names=[self.name],
                init_fn=init_fn,
            )
        else:
            return None

    @property
    def dense_emb_config(
        self,
    ) -> Optional[DenseEmbeddingConfig]:
        """Get DenseEmbeddingConfig of the feature."""
        if self._dense_emb_type:
            dense_emb_config = getattr(self.config, self._dense_emb_type)
            assert self.value_dim <= 1, (
                "dense embedding does not support"
                f" feature [{self.name}] with value_dim > 1 yet."
            )
            if self._dense_emb_type == "autodis":
                return AutoDisEmbeddingConfig(
                    embedding_dim=self._embedding_dim,
                    n_channels=dense_emb_config.num_channels,
                    temperature=dense_emb_config.temperature,
                    keep_prob=dense_emb_config.keep_prob,
                    feature_names=[self.name],
                )
            elif self._dense_emb_type == "mlp":
                return MLPDenseEmbeddingConfig(
                    embedding_dim=self._embedding_dim,
                    feature_names=[self.name],
                )

        return None

    def mc_module(self, device: torch.device) -> Optional[ManagedCollisionModule]:
        """Get ManagedCollisionModule."""
        if self.is_sparse:
            if hasattr(self.config, "zch") and self.config.HasField("zch"):
                evict_type = self.config.zch.WhichOneof("eviction_policy")
                evict_config = getattr(self.config.zch, evict_type)
                threshold_filtering_func = None
                if self.config.zch.HasField("threshold_filtering_func"):
                    threshold_filtering_func = eval(
                        self.config.zch.threshold_filtering_func
                    )
                if evict_type == "lfu":
                    eviction_policy = LFU_EvictionPolicy(
                        threshold_filtering_func=threshold_filtering_func
                    )
                elif evict_type == "lru":
                    eviction_policy = LRU_EvictionPolicy(
                        decay_exponent=evict_config.decay_exponent,
                        threshold_filtering_func=threshold_filtering_func,
                    )
                elif evict_type == "distance_lfu":
                    eviction_policy = DistanceLFU_EvictionPolicy(
                        decay_exponent=evict_config.decay_exponent,
                        threshold_filtering_func=threshold_filtering_func,
                    )
                else:
                    raise ValueError(f"Unknown evict policy type: {evict_type}")
                return MCHManagedCollisionModule(
                    zch_size=self.config.zch.zch_size,
                    device=device,
                    eviction_interval=self.config.zch.eviction_interval,
                    eviction_policy=eviction_policy,
                )
        return None
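
    # An example zch setting (a sketch; textproto field names follow
    # tzrec.protos.feature_pb2):
    #   zch {
    #     zch_size: 100000
    #     eviction_interval: 5
    #     lfu {}
    #   }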

    @property
    def inputs(self) -> List[str]:
        """Input field names."""
        if not self._inputs:
            if self.fg_mode in [FgMode.FG_NONE, FgMode.FG_BUCKETIZE]:
                self._inputs = [self.name]
            else:
                self._inputs = [v for _, v in self.side_inputs]
        return self._inputs

    @property
    def side_inputs(self) -> List[Tuple[str, str]]:
        """Input field names with side."""
        if self._side_inputs is None:
            side_inputs = self._build_side_inputs()
            if not side_inputs:
                raise InvalidFgInputError(
                    f"{self.__class__.__name__}[{self.name}] must have fg "
                    f"input names, e.g., item:cat_a."
                )
            for x in side_inputs:
                if not (len(x) == 2 and x[0] in ["user", "item", "context", "feature"]):
                    raise InvalidFgInputError(
                        f"{self.__class__.__name__}[{self.name}] must have valid fg "
                        f"input names, e.g., item:cat_a, but got {x}."
                    )
            self._side_inputs = side_inputs
        return self._side_inputs

    def _build_side_inputs(self) -> Optional[List[Tuple[str, str]]]:
        """Build input field names with side."""
        raise NotImplementedError

    def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData:
        """Parse input data for the feature impl.

        Args:
            input_data (dict): raw input feature data.

        Return:
            parsed feature data.
        """
        raise NotImplementedError

    def parse(
        self, input_data: Dict[str, pa.Array], is_training: bool = False
    ) -> ParsedData:
        """Parse input data for the feature.

        Args:
            input_data (dict): raw input feature data.
            is_training (bool): is training or not.

        Return:
            parsed feature data.
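
        Example:
            A minimal sketch of masking (hypothetical input column
            ``cat_a``; requires ``use_mask`` in the feature config)::

                batch = {
                    "cat_a": pa.array(["1", "2"]),
                    C_SAMPLE_MASK: pa.array([True, False]),
                }
                data = feature.parse(batch, is_training=True)
                # masked rows are nulled before parsing and handled by
                # default values downstream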
        """
        if is_training and self.config.use_mask:
            t_input_data = {}
            i = 0
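            # null out the first non-map input column; masked rows are
            # then handled by default values downstream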
            for name in self.inputs:
                data = input_data[name]
                if i == 0 and not pa.types.is_map(data.type):
                    mask = (
                        input_data[C_NEG_SAMPLE_MASK]
                        if self.is_neg
                        else input_data[C_SAMPLE_MASK]
                    )
                    data = pa.compute.if_else(mask, pa.nulls(len(data)), data)
                    i += 1
                t_input_data[name] = data
        else:
            t_input_data = input_data

        parsed_data = self._parse(t_input_data)
        return parsed_data

    def init_fg(self) -> None:
        """Init fg op."""
        if self._fg_op is None:
            cfgs = self.fg_json()
            if len(cfgs) > 1:
                # pyre-ignore [16]
                self._fg_op = pyfg.FgHandler({"features": cfgs}, 1)
            else:
                is_rank_zero = os.environ.get("RANK", "0") == "0"
                # pyre-ignore [16]
                self._fg_op = pyfg.FeatureFactory.create(cfgs[0], is_rank_zero)

    def fg_json(self) -> List[Dict[str, Any]]:
        """Get fg json config."""
        raise NotImplementedError

    def fg_encoded_default_value(self) -> Optional[Union[List[int], List[float]]]:
        """Get fg encoded default value."""
        if self.config.HasField("fg_encoded_default_value"):
            if self.config.fg_encoded_default_value == "":
                return None
            if self.is_sparse:
                return list(
                    map(
                        int,
                        self.config.fg_encoded_default_value.split(
                            self._fg_encoded_multival_sep
                        ),
                    )
                )
            else:
                return list(
                    map(
                        float,
                        self.config.fg_encoded_default_value.split(
                            self._fg_encoded_multival_sep
                        ),
                    )
                )
        else:
            # we try to initialize fg to get fg_encoded_default_value
            self.init_fg()
            # pyre-ignore [16]
            if isinstance(self._fg_op, pyfg.FgHandler):
                output, status = self._fg_op({x: [None] for _, x in self.side_inputs})
                assert status.ok(), status.message()
                default_value = output[self.name][0]
                self._fg_op.reset_executor()
            else:
                output = self._fg_op([[None] for _ in self.side_inputs])
                default_value = output[0]
            if default_value is not None:
                if not isinstance(default_value, list):
                    default_value = [default_value]
                elif len(default_value) == 0:
                    # empty list
                    default_value = None
                elif isinstance(default_value[0], list):
                    # list of list
                    default_value = default_value[0]
            return default_value

    @property
    def vocab_list(self) -> List[str]:
        """Vocab list."""
        if self._vocab_list is None:
            if len(self.config.vocab_list) > 0:
                if self.config.HasField("default_bucketize_value"):
                    # when default_bucketize_value is set, we do not add the
                    # additional `default_value` and <OOV> entries to vocab_list
                    assert self.config.default_bucketize_value < len(
                        self.config.vocab_list
                    ), (
                        "default_bucketize_value should be less than len(vocab_list) "
                        f"in {self.__class__.__name__}[{self.name}]"
                    )
                    self._vocab_list = list(self.config.vocab_list)
                else:
                    self._vocab_list = [self.config.default_value, "<OOV>"] + list(
                        self.config.vocab_list
                    )
            else:
                self._vocab_list = []
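        # e.g., default_value "xyz" with vocab_list ["a", "b"] yields
        # ["xyz", "<OOV>", "a", "b"], so real vocab ids start from 2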
        return self._vocab_list

    @property
    def vocab_dict(self) -> Dict[str, int]:
        """Vocab dict."""
        if self._vocab_dict is None:
            if len(self.config.vocab_dict) > 0:
                vocab_dict = OrderedDict(self.config.vocab_dict.items())
                if self.config.HasField("default_bucketize_value"):
                    # when default_bucketize_value is set, we do not add the
                    # additional `default_value` and <OOV> entries to vocab_dict
                    self._vocab_dict = vocab_dict
                else:
                    is_rank_zero = os.environ.get("RANK", "0") == "0"
                    if min(list(self.config.vocab_dict.values())) <= 1 and is_rank_zero:
                        logger.warning(
                            "min index of vocab_dict in "
                            f"{self.__class__.__name__}[{self.name}] should "
                            "start from 2: index 0 is reserved for "
                            "default_value and index 1 for <OOV>."
                        )
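                    # default_value is bucketized to id 0; id 1 stays <OOV>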
                    vocab_dict[self.config.default_value] = 0
                    self._vocab_dict = vocab_dict
            else:
                self._vocab_dict = {}
        return self._vocab_dict

    @property
    def vocab_file(self) -> str:
        """Vocab file."""
        if self.config.HasField("vocab_file"):
            if not self.config.HasField("default_bucketize_value"):
                raise ValueError(
                    "default_bucketize_value must be set when using vocab_file."
                )
            vocab_file = self.config.vocab_file
            if self.config.HasField("asset_dir"):
                vocab_file = os.path.join(self.config.asset_dir, vocab_file)
            return vocab_file
        else:
            return ""

    @property
    def default_bucketize_value(self) -> int:
        """Default bucketize value."""
        if self.config.HasField("default_bucketize_value"):
            return self.config.default_bucketize_value
        else:
            return 1

    def assets(self) -> Dict[str, str]:
        """Asset file paths."""
        return {}

    def __del__(self) -> None:
        # pyre-ignore [16]
        if self._fg_op and isinstance(self._fg_op, pyfg.FgHandler):
            self._fg_op.reset_executor()


def create_features(
    feature_configs: List[FeatureConfig],
    fg_mode: FgMode = FgMode.FG_NONE,
    neg_fields: Optional[List[str]] = None,
    fg_encoded_multival_sep: Optional[str] = None,
    force_base_data_group: bool = False,
) -> List[BaseFeature]:
    """Build feature list from feature config.

    Args:
        feature_configs (list): list of feature_config.
        fg_mode (FgMode): input data fg mode.
        neg_fields (list, optional): negative sampled input fields.
        fg_encoded_multival_sep (str, optional): multival_sep when fg_mode=FG_NONE.
        force_base_data_group (bool): force padding data into the same
            data group with the same batch_size.

    Return:
        features: list of Feature.
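
    Example:
        A minimal sketch (proto field names follow tzrec.protos.feature_pb2;
        treat the exact fields as assumptions)::

            from tzrec.protos import feature_pb2

            cfg = FeatureConfig(
                id_feature=feature_pb2.IdFeature(
                    feature_name="cat_a",
                    embedding_dim=16,
                    hash_bucket_size=1000,
                )
            )
            features = create_features([cfg])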
    """
    features = []
    for feat_config in feature_configs:
        oneof_feat_config = getattr(feat_config, feat_config.WhichOneof("feature"))
        feat_cls_name = oneof_feat_config.__class__.__name__
        if feat_cls_name == "SequenceFeature":
            sequence_name = oneof_feat_config.sequence_name
            sequence_delim = oneof_feat_config.sequence_delim
            sequence_length = oneof_feat_config.sequence_length
            sequence_pk = oneof_feat_config.sequence_pk
            for sub_feat_config in oneof_feat_config.features:
                sub_feat_cls_name = config_util.which_msg(sub_feat_config, "feature")
                # pyre-ignore [16]
                feature = BaseFeature.create_class(f"Sequence{sub_feat_cls_name}")(
                    sub_feat_config,
                    sequence_name=sequence_name,
                    sequence_delim=sequence_delim,
                    sequence_length=sequence_length,
                    sequence_pk=sequence_pk,
                    fg_mode=fg_mode,
                    fg_encoded_multival_sep=fg_encoded_multival_sep,
                )
                features.append(feature)
        else:
            feature = BaseFeature.create_class(feat_cls_name)(
                feat_config,
                fg_mode=fg_mode,
                fg_encoded_multival_sep=fg_encoded_multival_sep,
            )
            features.append(feature)

    has_dag = False
    for feature in features:
        if neg_fields:
            if len(set(feature.inputs) & set(neg_fields)):
                feature.is_neg = True
        if force_base_data_group:
            feature.data_group = BASE_DATA_GROUP
        try:
            side_inputs = feature.side_inputs
            for k, _ in side_inputs:
                if k == "feature":
                    has_dag = True
                    break
        except InvalidFgInputError:
            pass

    if has_dag:
        fg_json = create_fg_json(features)
        # pyre-ignore [16]
        fg_handler = pyfg.FgArrowHandler(fg_json, 1)
        user_feats = fg_handler.user_features() | set(
            fg_handler.sequence_feature_to_name().keys()
        )
        for feature in features:
            feature.is_user_feat = feature.name in user_feats

    return features


def _copy_assets(
    feature: BaseFeature,
    asset_dir: Optional[str] = None,
    use_relative_asset_dir: bool = False,
) -> BaseFeature:
    if asset_dir and len(feature.assets()) > 0:
        # deepcopy feature config
        feature_config = type(feature.feature_config)()
        feature_config.CopyFrom(feature.feature_config)
        feature = copy(feature)
        feature.feature_config = feature_config
        for k, v in feature.assets().items():
            with open(v, "rb") as f:
                fhash = hashlib.md5(f.read()).hexdigest()
            fprefix, fext = os.path.splitext(os.path.basename(v))
            fname = f"{fprefix}_{fhash}{fext}"
            fpath = os.path.join(asset_dir, fname)
            if not os.path.exists(fpath):
                shutil.copy(v, fpath)
            config_util.edit_config(feature.config, {k: fname})
            if not use_relative_asset_dir:
                feature.config.asset_dir = asset_dir
            else:
                feature.config.ClearField("asset_dir")
    return feature


def create_fg_json(
    features: List[BaseFeature], asset_dir: Optional[str] = None
) -> Dict[str, Any]:
    """Create feature generate config for features."""
    results = []
    seq_to_idx = {}
    for feature in features:
        feature = _copy_assets(feature, asset_dir, use_relative_asset_dir=True)
        if feature.is_grouped_sequence:
            # pyre-ignore [16]
            if feature.sequence_name not in seq_to_idx:
                results.append(
                    {
                        "sequence_name": feature.sequence_name,
                        "sequence_length": feature.sequence_length,  # pyre-ignore [16]
                        "sequence_delim": feature.sequence_delim,  # pyre-ignore [16]
                        "sequence_pk": feature.sequence_pk,  # pyre-ignore [16]
                        "features": [],
                    }
                )
                seq_to_idx[feature.sequence_name] = len(results) - 1
            fg_json = feature.fg_json()
            idx = seq_to_idx[feature.sequence_name]
            results[idx]["features"].extend(fg_json)
        else:
            fg_json = feature.fg_json()
            results.extend(fg_json)
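
    # grouped-sequence sub-features are nested under a single
    # "sequence_name" entry; all other features are flat entries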
    return {"features": results}


def create_feature_configs(
    features: List[BaseFeature], asset_dir: Optional[str] = None
) -> List[FeatureConfig]:
    """Create feature configs for features."""
    results = OrderedDict()
    for feature in features:
        feature = _copy_assets(feature, asset_dir)
        if feature.is_grouped_sequence:
            # pyre-ignore [16]
            if feature.sequence_name not in results:
                results[feature.sequence_name] = FeatureConfig(
                    sequence_feature=SequenceFeature(
                        sequence_name=feature.sequence_name,
                        sequence_length=feature.sequence_length,  # pyre-ignore [16]
                        sequence_delim=feature.sequence_delim,  # pyre-ignore [16]
                        sequence_pk=feature.sequence_pk,  # pyre-ignore [16]
                    )
                )
            results[feature.sequence_name].sequence_feature.features.append(
                feature.feature_config
            )
        else:
            results[feature.name] = feature.feature_config
    return list(results.values())
