# tzrec/features/feature.py
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import os
import shutil
from collections import OrderedDict
from copy import copy
from functools import partial # NOQA
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import pyarrow as pa
import pyfg
import torch
from torch import nn # NOQA
from torchrec.modules.embedding_configs import (
EmbeddingBagConfig,
EmbeddingConfig,
PoolingType,
)
from torchrec.modules.mc_modules import (
DistanceLFU_EvictionPolicy,
LFU_EvictionPolicy,
LRU_EvictionPolicy,
ManagedCollisionModule,
MCHManagedCollisionModule,
average_threshold_filter, # NOQA
dynamic_threshold_filter, # NOQA
probabilistic_threshold_filter, # NOQA
)
from tzrec.datasets.utils import (
BASE_DATA_GROUP,
C_NEG_SAMPLE_MASK,
C_SAMPLE_MASK,
NEG_DATA_GROUP,
DenseData,
ParsedData,
SparseData,
)
from tzrec.modules.dense_embedding_collection import (
AutoDisEmbeddingConfig,
DenseEmbeddingConfig,
MLPDenseEmbeddingConfig,
)
from tzrec.protos.data_pb2 import FgMode
from tzrec.protos.feature_pb2 import FeatureConfig, SequenceFeature
from tzrec.utils import config_util
from tzrec.utils.load_class import get_register_class_meta
from tzrec.utils.logging_util import logger
_FEATURE_CLASS_MAP = {}
_meta_cls = get_register_class_meta(_FEATURE_CLASS_MAP)
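# max value of a signed int64; upper bound for hash bucket sizes.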
MAX_HASH_BUCKET_SIZE = 2**63 - 1
def _parse_fg_encoded_sparse_feature_impl(
name: str,
feat: pa.Array,
multival_sep: str = chr(3),
default_value: Optional[List[int]] = None,
is_weighted: bool = False,
) -> SparseData:
    """Parse fg encoded sparse feature.

    Args:
        name (str): feature name.
        feat (pa.Array): input feature data.
        multival_sep (str): string separator for multi-val data.
        default_value (list): default value.
        is_weighted (bool): input feature is weighted or not.

    Returns:
        an instance of SparseData.
    """
weight_values = None
if (
pa.types.is_string(feat.type)
or pa.types.is_list(feat.type)
or pa.types.is_map(feat.type)
):
weight = None
if pa.types.is_string(feat.type) or pa.types.is_list(feat.type):
if pa.types.is_string(feat.type):
# dtype = string
is_empty = pa.compute.equal(feat, pa.scalar(""))
nulls = pa.nulls(len(feat))
feat = pa.compute.if_else(is_empty, nulls, feat)
feat = pa.compute.split_pattern(feat, multival_sep)
elif pa.types.is_list(feat.type):
# dtype = list<int> or others can cast to list<int>
if default_value is not None:
is_empty = pa.compute.equal(pa.compute.list_value_length(feat), 0)
nulls = pa.nulls(len(feat))
feat = pa.compute.if_else(is_empty, nulls, feat)
if is_weighted:
assert pa.types.is_string(feat.values.type)
fw = pa.compute.split_pattern(feat.values, ":")
weight = pa.ListArray.from_arrays(
feat.offsets, fw.values[1::2], mask=feat.is_null()
)
feat = pa.ListArray.from_arrays(
feat.offsets, fw.values[::2], mask=feat.is_null()
)
else:
# dtype = map<int,float> or others can cast to map<int,float>
weight = pa.ListArray.from_arrays(
feat.offsets, feat.items, mask=feat.is_null()
)
feat = pa.ListArray.from_arrays(
feat.offsets, feat.keys, mask=feat.is_null()
)
feat = feat.cast(pa.list_(pa.int64()), safe=False)
if weight is not None:
weight = weight.cast(pa.list_(pa.float32()), safe=False)
if default_value is not None:
feat = feat.fill_null(default_value)
            if weight is not None:
weight = weight.fill_null([1.0])
feat_values = feat.values.to_numpy()
feat_offsets = feat.offsets.to_numpy()
feat_lengths = feat_offsets[1:] - feat_offsets[:-1]
if weight is not None:
weight_values = weight.values.to_numpy()
elif pa.types.is_integer(feat.type):
assert not is_weighted
# dtype = int
if default_value is not None:
feat = feat.cast(pa.int64()).fill_null(default_value[0])
feat_values = feat.to_numpy()
feat_lengths = np.ones_like(feat_values, np.int32)
else:
feat_values = feat.drop_null().cast(pa.int64()).to_numpy()
feat_lengths = 1 - feat.is_null().cast(pa.int32()).to_numpy()
    else:
        raise ValueError(
            f"{name} only supports str|int|list<int>|map<int,double> dtype input, "
            f"but got {feat.type}."
        )
return SparseData(name, feat_values, feat_lengths, weights=weight_values)
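

# Example for _parse_fg_encoded_sparse_feature_impl (illustrative sketch;
# "\x03" below is the default multival_sep chr(3)):
#
#   feat = pa.array(["123\x03456", ""])
#   data = _parse_fg_encoded_sparse_feature_impl("cat_a", feat, default_value=[0])
#   # data.values -> [123, 456, 0]; data.lengths -> [2, 1]
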
def _parse_fg_encoded_dense_feature_impl(
name: str,
feat: pa.Array,
multival_sep: str = chr(3),
default_value: Optional[List[float]] = None,
) -> DenseData:
    """Parse fg encoded dense feature.

    Args:
        name (str): feature name.
        feat (pa.Array): input feature data.
        multival_sep (str): string separator for multi-val data.
        default_value (list): default value.

    Returns:
        an instance of DenseData.
    """
if pa.types.is_string(feat.type):
# dtype = string
if default_value is not None:
is_empty = pa.compute.equal(feat, pa.scalar(""))
feat = pa.compute.if_else(is_empty, pa.nulls(len(feat)), feat)
feat = feat.fill_null(multival_sep.join(map(str, default_value)))
list_feat = pa.compute.split_pattern(feat, multival_sep)
list_feat = list_feat.cast(pa.list_(pa.float32()), safe=False)
feat_values = np.stack(list_feat.to_numpy(zero_copy_only=False))
elif pa.types.is_list(feat.type):
# dtype = list<float> or others can cast to list<float>
feat = feat.cast(pa.list_(pa.float32()), safe=False)
if default_value is not None:
is_empty = pa.compute.equal(pa.compute.list_value_length(feat), 0)
feat = pa.compute.if_else(is_empty, pa.nulls(len(feat)), feat)
feat = feat.fill_null(default_value)
feat_values = np.stack(feat.to_numpy(zero_copy_only=False))
elif pa.types.is_integer(feat.type) or pa.types.is_floating(feat.type):
# dtype = int or float
feat = feat.cast(pa.float32(), safe=False)
if default_value is not None:
feat = feat.fill_null(default_value[0])
feat_values = feat.to_numpy()[:, np.newaxis]
    else:
        raise ValueError(
            f"{name} only supports str|int|float|list<float> dtype input,"
            f" but got {feat.type}."
        )
return DenseData(name, feat_values)
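

# Example for _parse_fg_encoded_dense_feature_impl (illustrative sketch):
#
#   feat = pa.array(["0.1\x030.2", ""])
#   data = _parse_fg_encoded_dense_feature_impl(
#       "price", feat, default_value=[0.0, 0.0]
#   )
#   # data.values -> [[0.1, 0.2], [0.0, 0.0]]  (shape [batch, dim])
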
class InvalidFgInputError(Exception):
"""Invalid Feature side inputs exception."""
pass
class BaseFeature(object, metaclass=_meta_cls):
    """Base feature class.

    Args:
        feature_config (FeatureConfig): an instance of feature config.
        fg_mode (FgMode): input data fg mode.
        fg_encoded_multival_sep (str, optional): multival_sep when fg_mode=FG_NONE.
    """
def __init__(
self,
feature_config: FeatureConfig,
fg_mode: FgMode = FgMode.FG_NONE,
fg_encoded_multival_sep: Optional[str] = None,
) -> None:
fc_type = feature_config.WhichOneof("feature")
self._feature_config = feature_config
self.config = getattr(self._feature_config, fc_type)
self.fg_mode = fg_mode
self._fg_op = None
self._is_neg = False
self._is_sparse = None
self._is_weighted = False
self._is_user_feat = None
self._data_group = BASE_DATA_GROUP
self._inputs = None
self._side_inputs = None
self._vocab_list = None
self._vocab_dict = None
self._fg_encoded_kwargs = {}
self._fg_encoded_multival_sep = fg_encoded_multival_sep or chr(3)
if self.fg_mode == FgMode.FG_NONE:
if self.config.HasField("fg_encoded_default_value"):
self._fg_encoded_kwargs["default_value"] = (
self.fg_encoded_default_value()
)
elif self.config.use_mask:
try:
self._fg_encoded_kwargs["default_value"] = (
self.fg_encoded_default_value()
)
except Exception:
                    raise RuntimeError(
                        "when use_mask is enabled, you should set "
                        f"fg_encoded_default_value for {self.name}"
                    ) from None
self._fg_encoded_kwargs["multival_sep"] = self._fg_encoded_multival_sep
if self.fg_mode == FgMode.FG_NORMAL:
self.init_fg()
@property
def name(self) -> str:
"""Feature name."""
raise NotImplementedError
@property
def is_neg(self) -> bool:
"""Feature is negative sampled or not."""
return self._is_neg
@is_neg.setter
def is_neg(self, value: bool) -> None:
"""Feature is negative sampled or not."""
        self._is_neg = value
        if value:
            self._data_group = NEG_DATA_GROUP
@property
def data_group(self) -> str:
"""Data group for the feature."""
return self._data_group
@data_group.setter
def data_group(self, data_group: str) -> None:
"""Data group for the feature."""
self._data_group = data_group
@property
def feature_config(self) -> FeatureConfig:
"""Feature config for the feature."""
return self._feature_config
@feature_config.setter
def feature_config(self, feature_config: FeatureConfig) -> None:
"""Feature config for the feature."""
fc_type = feature_config.WhichOneof("feature")
self._feature_config = feature_config
self.config = getattr(self._feature_config, fc_type)
@property
def is_user_feat(self) -> bool:
"""Feature is user side or not."""
if self._is_user_feat is None:
            # legacy path without dag; is_user_feat may not have been set
if self.is_grouped_sequence:
return True
for side, _ in self.side_inputs:
if side != "user":
return False
return True
else:
return self._is_user_feat
@is_user_feat.setter
def is_user_feat(self, value: bool) -> None:
"""Feature is user side or not."""
self._is_user_feat = value
@property
def value_dim(self) -> int:
"""Fg value dimension of the feature."""
raise NotImplementedError
@property
def output_dim(self) -> int:
"""Output dimension of the feature after embedding."""
raise NotImplementedError
@property
def is_sparse(self) -> bool:
"""Feature is sparse or dense."""
if self._is_sparse is None:
self._is_sparse = False
return self._is_sparse
@property
def is_sequence(self) -> bool:
"""Feature is sequence or not."""
return False
@property
def is_grouped_sequence(self) -> bool:
"""Feature is grouped sequence or not."""
return False
@property
def is_weighted(self) -> bool:
"""Feature is weighted id feature or not."""
return self._is_weighted
@property
def has_embedding(self) -> bool:
"""Feature has embedding or not."""
if self.is_sparse:
return True
else:
return self._dense_emb_type is not None
@property
def pooling_type(self) -> PoolingType:
"""Get embedding pooling type."""
pooling_type = self.config.pooling.upper()
        assert pooling_type in {"SUM", "MEAN"}, "available pooling types are SUM | MEAN"
return getattr(PoolingType, pooling_type)
@property
def num_embeddings(self) -> int:
"""Get embedding row count."""
raise NotImplementedError
@property
def _embedding_dim(self) -> int:
if self.has_embedding:
assert self.config.embedding_dim > 0, (
f"embedding_dim of {self.__class__.__name__}[{self.name}] "
"should be greater than 0."
)
return self.config.embedding_dim
@property
def _dense_emb_type(self) -> Optional[str]:
return None
@property
def emb_bag_config(self) -> Optional[EmbeddingBagConfig]:
"""Get EmbeddingBagConfig of the feature."""
if self.is_sparse:
embedding_name = self.config.embedding_name or f"{self.name}_emb"
init_fn = None
if self.config.HasField("init_fn"):
init_fn = eval(f"partial({self.config.init_fn})")
return EmbeddingBagConfig(
num_embeddings=self.num_embeddings,
embedding_dim=self._embedding_dim,
name=embedding_name,
feature_names=[self.name],
pooling=self.pooling_type,
init_fn=init_fn,
)
else:
return None
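    # Note: `init_fn` in the config holds the argument text of a partial, e.g.
    # init_fn: "nn.init.uniform_, b=0.01" is turned via eval into
    # partial(nn.init.uniform_, b=0.01); hence the `partial` and `nn` imports
    # kept with NOQA at the top of this file.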
@property
def emb_config(self) -> Optional[EmbeddingConfig]:
"""Get EmbeddingConfig of the feature."""
if self.is_sparse:
embedding_name = self.config.embedding_name or f"{self.name}_emb"
init_fn = None
if self.config.HasField("init_fn"):
init_fn = eval(f"partial({self.config.init_fn})")
return EmbeddingConfig(
num_embeddings=self.num_embeddings,
embedding_dim=self._embedding_dim,
name=embedding_name,
feature_names=[self.name],
init_fn=init_fn,
)
else:
return None
@property
def dense_emb_config(
self,
) -> Optional[DenseEmbeddingConfig]:
"""Get DenseEmbeddingConfig of the feature."""
if self._dense_emb_type:
dense_emb_config = getattr(self.config, self._dense_emb_type)
assert self.value_dim <= 1, (
"dense embedding do not support"
f" feature [{self.name}] with value_dim > 1 now."
)
if self._dense_emb_type == "autodis":
return AutoDisEmbeddingConfig(
embedding_dim=self._embedding_dim,
n_channels=dense_emb_config.num_channels,
temperature=dense_emb_config.temperature,
keep_prob=dense_emb_config.keep_prob,
feature_names=[self.name],
)
elif self._dense_emb_type == "mlp":
return MLPDenseEmbeddingConfig(
embedding_dim=self._embedding_dim,
feature_names=[self.name],
)
return None
def mc_module(self, device: torch.device) -> Optional[ManagedCollisionModule]:
"""Get ManagedCollisionModule."""
if self.is_sparse:
if hasattr(self.config, "zch") and self.config.HasField("zch"):
evict_type = self.config.zch.WhichOneof("eviction_policy")
evict_config = getattr(self.config.zch, evict_type)
threshold_filtering_func = None
if self.config.zch.HasField("threshold_filtering_func"):
threshold_filtering_func = eval(
self.config.zch.threshold_filtering_func
)
if evict_type == "lfu":
eviction_policy = LFU_EvictionPolicy(
threshold_filtering_func=threshold_filtering_func
)
elif evict_type == "lru":
eviction_policy = LRU_EvictionPolicy(
decay_exponent=evict_config.decay_exponent,
threshold_filtering_func=threshold_filtering_func,
)
elif evict_type == "distance_lfu":
eviction_policy = DistanceLFU_EvictionPolicy(
decay_exponent=evict_config.decay_exponent,
threshold_filtering_func=threshold_filtering_func,
)
else:
                    raise ValueError(f"Unknown evict policy type: {evict_type}")
return MCHManagedCollisionModule(
zch_size=self.config.zch.zch_size,
device=device,
eviction_interval=self.config.zch.eviction_interval,
eviction_policy=eviction_policy,
)
return None
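    # Example zch config (textproto sketch, matching the fields read above):
    #   zch {
    #     zch_size: 1000000
    #     eviction_interval: 2
    #     lfu {}
    #   }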
@property
def inputs(self) -> List[str]:
"""Input field names."""
if not self._inputs:
if self.fg_mode in [FgMode.FG_NONE, FgMode.FG_BUCKETIZE]:
self._inputs = [self.name]
else:
self._inputs = [v for _, v in self.side_inputs]
return self._inputs
@property
def side_inputs(self) -> List[Tuple[str, str]]:
"""Input field names with side."""
if self._side_inputs is None:
side_inputs = self._build_side_inputs()
if not side_inputs:
raise InvalidFgInputError(
f"{self.__class__.__name__}[{self.name}] must have fg "
f"input names, e.g., item:cat_a."
)
for x in side_inputs:
if not (len(x) == 2 and x[0] in ["user", "item", "context", "feature"]):
raise InvalidFgInputError(
f"{self.__class__.__name__}[{self.name}] must have valid fg "
f"input names, e.g., item:cat_a, but got {x}."
)
self._side_inputs = side_inputs
return self._side_inputs
def _build_side_inputs(self) -> Optional[List[Tuple[str, str]]]:
"""Build input field names with side."""
        raise NotImplementedError
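    # For a config such as `expression: "item:cat_a"`, a subclass's
    # _build_side_inputs typically returns [("item", "cat_a")]; features with
    # multiple inputs return several (side, name) pairs (sketch; the exact
    # result depends on the subclass).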
    def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData:
        """Parse input data for the feature impl.

        Args:
            input_data (dict): raw input feature data.

        Returns:
            parsed feature data.
        """
raise NotImplementedError
def parse(
self, input_data: Dict[str, pa.Array], is_training: bool = False
    ) -> ParsedData:
        """Parse input data for the feature.

        Args:
            input_data (dict): raw input feature data.
            is_training (bool): is training or not.

        Returns:
            parsed feature data.
        """
if is_training and self.config.use_mask:
t_input_data = {}
i = 0
for name in self.inputs:
data = input_data[name]
if i == 0 and not pa.types.is_map(data.type):
mask = (
input_data[C_NEG_SAMPLE_MASK]
if self.is_neg
else input_data[C_SAMPLE_MASK]
)
data = pa.compute.if_else(mask, pa.nulls(len(data)), data)
i += 1
t_input_data[name] = data
else:
t_input_data = input_data
parsed_data = self._parse(t_input_data)
return parsed_data
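    # Masking sketch: during training with use_mask enabled, the first
    # (non-map) input column is nulled for rows where C_SAMPLE_MASK (or
    # C_NEG_SAMPLE_MASK for negative-sampled features) is set, so those
    # rows fall back to the feature's default value in _parse.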
def init_fg(self) -> None:
"""Init fg op."""
if self._fg_op is None:
cfgs = self.fg_json()
if len(cfgs) > 1:
# pyre-ignore [16]
self._fg_op = pyfg.FgHandler({"features": cfgs}, 1)
else:
is_rank_zero = os.environ.get("RANK", "0") == "0"
# pyre-ignore [16]
self._fg_op = pyfg.FeatureFactory.create(cfgs[0], is_rank_zero)
def fg_json(self) -> List[Dict[str, Any]]:
"""Get fg json config."""
raise NotImplementedError
def fg_encoded_default_value(self) -> Optional[Union[List[int], List[float]]]:
"""Get fg encoded default value."""
if self.config.HasField("fg_encoded_default_value"):
if self.config.fg_encoded_default_value == "":
return None
if self.is_sparse:
return list(
map(
int,
self.config.fg_encoded_default_value.split(
self._fg_encoded_multival_sep
),
)
)
else:
return list(
map(
float,
self.config.fg_encoded_default_value.split(
self._fg_encoded_multival_sep
),
)
)
else:
# we try to initialize fg to get fg_encoded_default_value
self.init_fg()
# pyre-ignore [16]
if isinstance(self._fg_op, pyfg.FgHandler):
output, status = self._fg_op({x: [None] for _, x in self.side_inputs})
assert status.ok(), status.message()
default_value = output[self.name][0]
self._fg_op.reset_executor()
else:
output = self._fg_op([[None] for _ in self.side_inputs])
default_value = output[0]
if default_value is not None:
if not isinstance(default_value, list):
default_value = [default_value]
elif len(default_value) == 0:
# empty list
default_value = None
elif isinstance(default_value[0], list):
# list of list
default_value = default_value[0]
return default_value
@property
def vocab_list(self) -> List[str]:
"""Vocab list."""
if self._vocab_list is None:
if len(self.config.vocab_list) > 0:
if self.config.HasField("default_bucketize_value"):
# when set default_bucketize_value, we do not add additional
# `default_value` and <OOV> vocab to vocab_list
assert self.config.default_bucketize_value < len(
self.config.vocab_list
), (
"default_bucketize_value should be less than len(vocab_list) "
f"in {self.__class__.__name__}[{self.name}]"
)
self._vocab_list = list(self.config.vocab_list)
else:
self._vocab_list = [self.config.default_value, "<OOV>"] + list(
self.config.vocab_list
)
else:
self._vocab_list = []
return self._vocab_list
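    # Example: with vocab_list ["a", "b"] and no default_bucketize_value, the
    # effective vocab becomes [default_value, "<OOV>", "a", "b"], i.e. "a"
    # bucketizes to id 2 and "b" to id 3.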
@property
def vocab_dict(self) -> Dict[str, int]:
"""Vocab dict."""
if self._vocab_dict is None:
if len(self.config.vocab_dict) > 0:
vocab_dict = OrderedDict(self.config.vocab_dict.items())
if self.config.HasField("default_bucketize_value"):
# when set default_bucketize_value, we do not add additional
# `default_value` and <OOV> vocab to vocab_dict
self._vocab_dict = vocab_dict
else:
is_rank_zero = os.environ.get("RANK", "0") == "0"
if min(list(self.config.vocab_dict.values())) <= 1 and is_rank_zero:
                    logger.warning(
                        "indices of vocab_dict in "
                        f"{self.__class__.__name__}[{self.name}] should "
                        "start from 2; index 0 is reserved for default_value "
                        "and index 1 for <OOV>."
                    )
vocab_dict[self.config.default_value] = 0
self._vocab_dict = vocab_dict
else:
self._vocab_dict = {}
return self._vocab_dict
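    # Example: vocab_dict {"a": 2, "b": 3} keeps ids 0/1 reserved for
    # default_value/<OOV>; when default_bucketize_value is set, the dict
    # is used verbatim instead.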
@property
def vocab_file(self) -> str:
"""Vocab file."""
if self.config.HasField("vocab_file"):
if not self.config.HasField("default_bucketize_value"):
raise ValueError(
"default_bucketize_value must be set when use vocab_file."
)
vocab_file = self.config.vocab_file
if self.config.HasField("asset_dir"):
vocab_file = os.path.join(self.config.asset_dir, vocab_file)
return vocab_file
else:
return ""
@property
def default_bucketize_value(self) -> int:
"""Default bucketize value."""
if self.config.HasField("default_bucketize_value"):
return self.config.default_bucketize_value
else:
return 1
def assets(self) -> Dict[str, str]:
"""Asset file paths."""
return {}
def __del__(self) -> None:
# pyre-ignore [16]
if self._fg_op and isinstance(self._fg_op, pyfg.FgHandler):
self._fg_op.reset_executor()
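

# Concrete features (e.g. IdFeature and RawFeature in sibling modules) are
# registered into _FEATURE_CLASS_MAP by the _meta_cls metaclass above and are
# instantiated by class name via BaseFeature.create_class(...) below.
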
def create_features(
feature_configs: List[FeatureConfig],
fg_mode: FgMode = FgMode.FG_NONE,
neg_fields: Optional[List[str]] = None,
fg_encoded_multival_sep: Optional[str] = None,
force_base_data_group: bool = False,
) -> List[BaseFeature]:
    """Build feature list from feature config.

    Args:
        feature_configs (list): list of feature_config.
        fg_mode (FgMode): input data fg mode.
        neg_fields (list, optional): negative sampled input fields.
        fg_encoded_multival_sep (str, optional): multival_sep when fg_mode=FG_NONE.
        force_base_data_group (bool): force all features into the same
            data group with the same batch_size.

    Returns:
        features: list of Feature.
    """
features = []
for feat_config in feature_configs:
oneof_feat_config = getattr(feat_config, feat_config.WhichOneof("feature"))
feat_cls_name = oneof_feat_config.__class__.__name__
if feat_cls_name == "SequenceFeature":
sequence_name = oneof_feat_config.sequence_name
sequence_delim = oneof_feat_config.sequence_delim
sequence_length = oneof_feat_config.sequence_length
sequence_pk = oneof_feat_config.sequence_pk
for sub_feat_config in oneof_feat_config.features:
sub_feat_cls_name = config_util.which_msg(sub_feat_config, "feature")
# pyre-ignore [16]
feature = BaseFeature.create_class(f"Sequence{sub_feat_cls_name}")(
sub_feat_config,
sequence_name=sequence_name,
sequence_delim=sequence_delim,
sequence_length=sequence_length,
sequence_pk=sequence_pk,
fg_mode=fg_mode,
fg_encoded_multival_sep=fg_encoded_multival_sep,
)
features.append(feature)
else:
feature = BaseFeature.create_class(feat_cls_name)(
feat_config,
fg_mode=fg_mode,
fg_encoded_multival_sep=fg_encoded_multival_sep,
)
features.append(feature)
has_dag = False
for feature in features:
if neg_fields:
if len(set(feature.inputs) & set(neg_fields)):
feature.is_neg = True
if force_base_data_group:
feature.data_group = BASE_DATA_GROUP
try:
side_inputs = feature.side_inputs
for k, _ in side_inputs:
if k == "feature":
has_dag = True
break
except InvalidFgInputError:
pass
if has_dag:
fg_json = create_fg_json(features)
# pyre-ignore [16]
fg_handler = pyfg.FgArrowHandler(fg_json, 1)
user_feats = fg_handler.user_features() | set(
fg_handler.sequence_feature_to_name().keys()
)
for feature in features:
feature.is_user_feat = feature.name in user_feats
return features
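

# Example usage (illustrative sketch; proto field names assumed from
# tzrec/protos/feature.proto):
#
#   from tzrec.protos import feature_pb2
#   feat_cfg = FeatureConfig(
#       id_feature=feature_pb2.IdFeature(
#           feature_name="cat_a", expression="item:cat_a", num_buckets=1000
#       )
#   )
#   features = create_features([feat_cfg])
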
def _copy_assets(
feature: BaseFeature,
asset_dir: Optional[str] = None,
use_relative_asset_dir: bool = False,
) -> BaseFeature:
if asset_dir and len(feature.assets()) > 0:
# deepcopy feature config
feature_config = type(feature.feature_config)()
feature_config.CopyFrom(feature.feature_config)
feature = copy(feature)
feature.feature_config = feature_config
for k, v in feature.assets().items():
with open(v, "rb") as f:
fhash = hashlib.md5(f.read()).hexdigest()
fprefix, fext = os.path.splitext(os.path.basename(v))
fname = f"{fprefix}_{fhash}{fext}"
fpath = os.path.join(asset_dir, fname)
if not os.path.exists(fpath):
shutil.copy(v, fpath)
config_util.edit_config(feature.config, {k: fname})
if not use_relative_asset_dir:
feature.config.asset_dir = asset_dir
else:
feature.config.ClearField("asset_dir")
return feature
def create_fg_json(
features: List[BaseFeature], asset_dir: Optional[str] = None
) -> Dict[str, Any]:
"""Create feature generate config for features."""
results = []
seq_to_idx = {}
for feature in features:
feature = _copy_assets(feature, asset_dir, use_relative_asset_dir=True)
if feature.is_grouped_sequence:
# pyre-ignore [16]
if feature.sequence_name not in seq_to_idx:
results.append(
{
"sequence_name": feature.sequence_name,
"sequence_length": feature.sequence_length, # pyre-ignore [16]
"sequence_delim": feature.sequence_delim, # pyre-ignore [16]
"sequence_pk": feature.sequence_pk, # pyre-ignore [16]
"features": [],
}
)
seq_to_idx[feature.sequence_name] = len(results) - 1
fg_json = feature.fg_json()
idx = seq_to_idx[feature.sequence_name]
results[idx]["features"].extend(fg_json)
else:
fg_json = feature.fg_json()
results.extend(fg_json)
return {"features": results}
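

# The returned dict mirrors the fg.json layout consumed by pyfg, roughly:
#   {"features": [{...flat feature cfg...},
#                 {"sequence_name": ..., "sequence_delim": ...,
#                  "features": [...sub-feature cfgs...]}]}
# (shape inferred from the grouping logic above).
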
def create_feature_configs(
features: List[BaseFeature], asset_dir: Optional[str] = None
) -> List[FeatureConfig]:
"""Create feature configs for features."""
results = OrderedDict()
for feature in features:
feature = _copy_assets(feature, asset_dir)
if feature.is_grouped_sequence:
# pyre-ignore [16]
if feature.sequence_name not in results:
results[feature.sequence_name] = FeatureConfig(
sequence_feature=SequenceFeature(
sequence_name=feature.sequence_name,
sequence_length=feature.sequence_length, # pyre-ignore [16]
sequence_delim=feature.sequence_delim, # pyre-ignore [16]
sequence_pk=feature.sequence_pk, # pyre-ignore [16]
)
)
results[feature.sequence_name].sequence_feature.features.append(
feature.feature_config
)
else:
results[feature.name] = feature.feature_config
return list(results.values())
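

# Note: create_feature_configs() is roughly the inverse of create_features():
# sub-features that came from one SequenceFeature are regrouped under a single
# FeatureConfig keyed by sequence_name, preserving the original grouping.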