# tzrec/features/lookup_feature.py
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import pyarrow as pa
from tzrec.datasets.utils import (
CROSS_NEG_DATA_GROUP,
DenseData,
ParsedData,
SparseData,
)
from tzrec.features.feature import (
MAX_HASH_BUCKET_SIZE,
BaseFeature,
FgMode,
_parse_fg_encoded_dense_feature_impl,
_parse_fg_encoded_sparse_feature_impl,
)
from tzrec.protos.feature_pb2 import FeatureConfig
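# A lookup feature looks up the value(s) for key(s) in a map-typed input
# field and then either bucketizes the result into sparse ids (zch,
# hash_bucket_size, num_buckets, vocab_* or boundaries) or emits dense
# floats. A minimal config sketch, assuming the usual `lookup_feature` entry
# inside FeatureConfig (the feature / field names below are illustrative
# only):
#
#   feature_configs {
#       lookup_feature {
#           feature_name: "user_click_cate_cnt"
#           map: "user:click_cate_cnt_map"
#           key: "item:cate_id"
#           combiner: "sum"
#           boundaries: [0, 1, 5, 10]
#           embedding_dim: 8
#       }
#   }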
class LookupFeature(BaseFeature):
"""LookupFeature class.
Args:
        feature_config (FeatureConfig): an instance of feature config.
fg_mode (FgMode): input data fg mode.
fg_encoded_multival_sep (str, optional): multival_sep when fg_mode=FG_NONE
"""
def __init__(
self,
feature_config: FeatureConfig,
fg_mode: FgMode = FgMode.FG_NONE,
fg_encoded_multival_sep: Optional[str] = None,
) -> None:
super().__init__(feature_config, fg_mode, fg_encoded_multival_sep)
@property
def name(self) -> str:
"""Feature name."""
return self.config.feature_name
# pyre-ignore [56]
@BaseFeature.is_neg.setter
def is_neg(self, value: bool) -> None:
"""Feature is negative sampled or not."""
self._is_neg = value
self._data_group = CROSS_NEG_DATA_GROUP
@property
def value_dim(self) -> int:
"""Fg value dimension of the feature."""
if self.config.HasField("value_dim"):
return self.config.value_dim
else:
return 1
@property
def output_dim(self) -> int:
"""Output dimension of the feature after embedding."""
if self.has_embedding:
return self.config.embedding_dim
else:
return max(self.config.value_dim, 1)
@property
def is_sparse(self) -> bool:
"""Feature is sparse or dense."""
if self._is_sparse is None:
self._is_sparse = (
self.config.HasField("zch")
or self.config.HasField("hash_bucket_size")
or self.config.HasField("num_buckets")
or len(self.vocab_list) > 0
or len(self.vocab_dict) > 0
or len(self.vocab_file) > 0
or len(self.config.boundaries) > 0
)
return self._is_sparse
@property
def num_embeddings(self) -> int:
"""Get embedding row count."""
if self.config.HasField("zch"):
num_embeddings = self.config.zch.zch_size
elif self.config.HasField("hash_bucket_size"):
num_embeddings = self.config.hash_bucket_size
elif self.config.HasField("num_buckets"):
num_embeddings = self.config.num_buckets
elif len(self.vocab_list) > 0:
num_embeddings = len(self.vocab_list)
elif len(self.vocab_dict) > 0:
num_embeddings = max(list(self.vocab_dict.values())) + 1
elif len(self.vocab_file) > 0:
self.init_fg()
num_embeddings = self._fg_op.vocab_list_size()
else:
num_embeddings = len(self.config.boundaries) + 1
return num_embeddings
@property
def _dense_emb_type(self) -> Optional[str]:
return self.config.WhichOneof("dense_emb")
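    # A lookup feature reads two side-qualified inputs, e.g.
    # map="user:click_cate_cnt_map" and key="item:cate_id"; each is split into
    # a (side, field_name) tuple.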
def _build_side_inputs(self) -> Optional[List[Tuple[str, str]]]:
"""Input field names with side."""
if self.config.HasField("map") and self.config.HasField("key"):
return [tuple(x.split(":")) for x in [self.config.map, self.config.key]]
else:
return None
def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData:
"""Parse input data for the feature impl.
Args:
input_data (dict): raw input feature data.
Return:
parsed feature data.
"""
if self.fg_mode == FgMode.FG_NONE:
            # input feature is already looked up (fg-encoded)
feat = input_data[self.name]
if self.is_sparse:
parsed_feat = _parse_fg_encoded_sparse_feature_impl(
self.name, feat, **self._fg_encoded_kwargs
)
else:
parsed_feat = _parse_fg_encoded_dense_feature_impl(
self.name, feat, **self._fg_encoded_kwargs
)
elif self.fg_mode == FgMode.FG_NORMAL:
input_feats = []
for name in self.inputs:
x = input_data[name]
if pa.types.is_list(x.type):
x = x.fill_null([])
elif pa.types.is_map(x.type):
x = x.fill_null({})
input_feats.append(x.tolist())
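            # When value_dim > 1 the composed fg pipeline (stub lookup feature
            # plus raw feature, see fg_json) is executed via process();
            # otherwise the single fg op transforms the inputs directly.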
if self.config.value_dim > 1:
fgout, status = self._fg_op.process(dict(zip(self.inputs, input_feats)))
assert status.ok(), status.message()
if self.is_sparse:
values = np.asarray(fgout[self.name].values, np.int64)
lengths = np.asarray(fgout[self.name].lengths, np.int32)
parsed_feat = SparseData(
name=self.name, values=values, lengths=lengths
)
else:
values = fgout[self.name].dense_values
parsed_feat = DenseData(name=self.name, values=values)
else:
if self.is_sparse:
values, lengths = self._fg_op.to_bucketized_jagged_tensor(
*input_feats
)
parsed_feat = SparseData(
name=self.name, values=values, lengths=lengths
)
else:
values = self._fg_op.transform(*input_feats)
parsed_feat = DenseData(name=self.name, values=values)
else:
raise ValueError(
f"fg_mode: {self.fg_mode} is not supported without fg handler."
)
return parsed_feat
def fg_json(self) -> List[Dict[str, Any]]:
"""Get fg json config."""
fg_cfg = {
"feature_type": "lookup_feature",
"feature_name": self.name,
"map": self.config.map,
"key": self.config.key,
"default_value": self.config.default_value,
"value_type": "float",
"needDiscrete": self.config.need_discrete,
"needKey": self.config.need_key,
"combiner": self.config.combiner.lower(),
}
raw_fg_cfg = None
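        # A multi-valued lookup (value_dim > 1) is emitted as two fg configs:
        # a stub lookup_feature that produces the joined string value, plus a
        # raw_feature that splits it by value_separator and applies the
        # normalizer / boundaries.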
if self.config.value_dim > 1:
fg_cfg["feature_name"] = self.name + "__lookup"
fg_cfg["default_value"] = ""
fg_cfg["value_type"] = "string"
fg_cfg["value_dim"] = 1
fg_cfg["needDiscrete"] = True
fg_cfg["combiner"] = ""
fg_cfg["stub_type"] = True
raw_fg_cfg = {
"feature_type": "raw_feature",
"feature_name": self.name,
"default_value": self.config.default_value,
"expression": "feature:" + fg_cfg["feature_name"],
"separator": self.config.value_separator,
"value_dim": self.config.value_dim,
"value_type": "float",
}
if self.config.HasField("normalizer"):
raw_fg_cfg["normalizer"] = self.config.normalizer
if len(self.config.boundaries) > 0:
raw_fg_cfg["boundaries"] = list(self.config.boundaries)
else:
if self.config.separator != "\x1d":
fg_cfg["separator"] = self.config.separator
if self.config.HasField("normalizer"):
fg_cfg["normalizer"] = self.config.normalizer
if self.config.HasField("zch"):
fg_cfg["hash_bucket_size"] = MAX_HASH_BUCKET_SIZE
fg_cfg["value_type"] = "string"
elif self.config.HasField("hash_bucket_size"):
fg_cfg["hash_bucket_size"] = self.config.hash_bucket_size
fg_cfg["value_type"] = "string"
elif self.config.HasField("num_buckets"):
fg_cfg["num_buckets"] = self.config.num_buckets
fg_cfg["value_type"] = "string"
elif len(self.vocab_list) > 0:
fg_cfg["vocab_list"] = self.vocab_list
fg_cfg["default_bucketize_value"] = self.default_bucketize_value
fg_cfg["value_type"] = "string"
elif len(self.vocab_dict) > 0:
fg_cfg["vocab_dict"] = self.vocab_dict
fg_cfg["default_bucketize_value"] = self.default_bucketize_value
fg_cfg["value_type"] = "string"
elif len(self.vocab_file) > 0:
fg_cfg["vocab_file"] = self.vocab_file
fg_cfg["default_bucketize_value"] = self.default_bucketize_value
fg_cfg["value_type"] = "string"
elif len(self.config.boundaries) > 0:
fg_cfg["boundaries"] = list(self.config.boundaries)
if self.config.HasField("fg_value_type"):
fg_cfg["value_type"] = self.config.fg_value_type
if fg_cfg["value_type"] == "string":
fg_cfg["needDiscrete"] = True
if fg_cfg["needDiscrete"]:
fg_cfg["combiner"] = ""
if fg_cfg["combiner"] == "":
fg_cfg["value_dim"] = self.value_dim
fg_cfgs = [fg_cfg]
if raw_fg_cfg is not None:
fg_cfgs.append(raw_fg_cfg)
return fg_cfgs
def assets(self) -> Dict[str, str]:
"""Asset file paths."""
assets = {}
if len(self.vocab_file) > 0:
assets["vocab_file"] = self.vocab_file
return assets