# tzrec/tools/add_feature_info_to_config.py
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
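
"""Tool to add feature_info (embedding dims, boundaries, hash bucket sizes and
drop flags) from a feature-info table into a pipeline config.

Example invocation (illustrative; adjust paths to your own environment):

    python -m tzrec.tools.add_feature_info_to_config \
        --template_model_config_path template.config \
        --model_config_path model.config \
        --config_table_path odps://your_project/tables/feature_info_table \
        --reader_type OdpsReader
"""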
import argparse
import json
from typing import Any, Dict, List, Tuple

from tzrec.datasets.dataset import create_reader
from tzrec.protos.pipeline_pb2 import EasyRecConfig
from tzrec.utils import config_util
from tzrec.utils.logging_util import logger


class AddFeatureInfoToConfig(object):
"""Add feature_info to config file.
Args:
template_model_config_path (str): template model config path.
model_config_path (str): model_config_path.
config_table_path (str): feature config info path.
reader_type (str): input path reader type.
odps_data_quota_name (str):maxcompute storage api/tunnel data quota name.
"""
def __init__(
self,
template_model_config_path: str,
model_config_path: str,
config_table_path: str,
reader_type: str,
odps_data_quota_name: str,
) -> None:
self.template_model_config_path = template_model_config_path
self.model_config_path = model_config_path
self.config_table_path = config_table_path
self.reader_type = reader_type
        self.odps_data_quota_name = odps_data_quota_name

    def _load_feature_info(self) -> Tuple[Dict[str, Any], List[str]]:
        """Load the feature info map and the names of features to drop."""
feature_info_map = {}
drop_feature_names = []
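        # The info table is expected to provide three columns: the feature name,
        # a JSON string with its updated config values, and an optional message
        # (a message containing "DROP IT" marks the feature for removal).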
sels = ["feature", "feature_info", "message"]
reader = create_reader(
self.config_table_path,
1,
selected_cols=sels,
reader_type=self.reader_type,
quota_name=self.odps_data_quota_name,
)
for data in reader.to_batches():
feature_names = data["feature"].tolist()
feature_infos = data["feature_info"].tolist()
messages = data["message"].tolist()
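            # Each feature_info cell is a JSON object; illustrative examples:
            #   {"embedding_dim": 8, "boundary": [0.1, 0.5, 1.0]}
            #   {"embedding_dim": 16, "hash_bucket_size": 100000}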
for record in zip(feature_names, feature_infos, messages):
feature_name = record[0]
feature_info_map[feature_name] = json.loads(record[1])
if record[2] is not None and "DROP IT" in record[2]:
drop_feature_names.append(feature_name)
        return feature_info_map, drop_feature_names

    def _drop_feature_config(
        self, pipeline_config: EasyRecConfig, drop_feature_names: List[str]
    ) -> None:
        """Drop the feature configs of features marked to be dropped."""
feature_configs = pipeline_config.feature_configs
filter_feature_configs = []
if drop_feature_names:
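            # Iterate over a snapshot of the repeated field so that removing
            # entries during iteration is safe.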
for fea_cfg in feature_configs[:]:
oneof_feat_config = getattr(fea_cfg, fea_cfg.WhichOneof("feature"))
feat_cls_name = oneof_feat_config.__class__.__name__
if feat_cls_name == "SequenceFeature":
sequence_name = oneof_feat_config.sequence_name
sub_features = oneof_feat_config.features[:]
for sub_feat_config in sub_features:
feat_config = getattr(
sub_feat_config, sub_feat_config.WhichOneof("feature")
)
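                        # Sub-features of a sequence are keyed in the info table
                        # as "<sequence_name>__<sub_feature_name>".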
name = f"{sequence_name}__{feat_config.feature_name}"
if name in drop_feature_names:
oneof_feat_config.features.remove(sub_feat_config)
logger.info(f"drop sub sequence feature: {name}")
if len(oneof_feat_config.features) == 0:
feature_configs.remove(fea_cfg)
logger.info(f"drop sequence feature: {sequence_name}")
else:
filter_feature_configs.append(fea_cfg)
else:
if oneof_feat_config.feature_name in drop_feature_names:
feature_configs.remove(fea_cfg)
logger.info(f"drop feature: {oneof_feat_config.feature_name}")
else:
filter_feature_configs.append(fea_cfg)
            # Rebuild the repeated field from the kept configs. Extending from the
            # explicit filter list avoids relying on the cleared container still
            # holding its old items.
            pipeline_config.ClearField("feature_configs")
            pipeline_config.feature_configs.extend(filter_feature_configs)

    def _update_feature_config(
        self, pipeline_config: EasyRecConfig, feature_info_map: Dict[str, Any]
    ) -> List[str]:
        """Update feature configs with the loaded feature info."""
feature_configs = pipeline_config.feature_configs
general_feature = []
for fea_cfg in feature_configs:
feature_config = getattr(fea_cfg, fea_cfg.WhichOneof("feature"))
feat_cls_name = feature_config.__class__.__name__
if feat_cls_name == "SequenceFeature":
sequence_name = feature_config.sequence_name
sub_features = feature_config.features
for sub_feat in sub_features:
sub_feat_config = getattr(sub_feat, sub_feat.WhichOneof("feature"))
feature_name = f"{sequence_name}__{sub_feat_config.feature_name}"
if feature_name in feature_info_map:
logger.info("edited %s" % feature_name)
sub_feat_config.embedding_dim = int(
feature_info_map[feature_name]["embedding_dim"]
)
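                        # "boundary" (bucketized feature) takes precedence over
                        # "hash_bucket_size" if both are present.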
if "boundary" in feature_info_map[feature_name]:
sub_feat_config.ClearField("boundaries")
sub_feat_config.boundaries.extend(
[
float(i)
for i in feature_info_map[feature_name]["boundary"]
]
)
elif "hash_bucket_size" in feature_info_map[feature_name]:
sub_feat_config.hash_bucket_size = int(
feature_info_map[feature_name]["hash_bucket_size"]
)
else:
                            logger.error(
                                f"please check {feature_name}: no boundary or "
                                "hash_bucket_size found in its feature info."
                            )
else:
feature_name = feature_config.feature_name
general_feature.append(feature_name)
if feature_name in feature_info_map:
logger.info("edited %s" % feature_name)
feature_config.embedding_dim = int(
feature_info_map[feature_name]["embedding_dim"]
)
if "boundary" in feature_info_map[feature_name]:
feature_config.ClearField("boundaries")
feature_config.boundaries.extend(
[
float(i)
for i in feature_info_map[feature_name]["boundary"]
]
)
elif "hash_bucket_size" in feature_info_map[feature_name]:
feature_config.hash_bucket_size = int(
feature_info_map[feature_name]["hash_bucket_size"]
)
else:
                        logger.error(
                            f"please check {feature_name}: no boundary or "
                            "hash_bucket_size found in its feature info."
                        )
        return general_feature

    def _update_feature_group(
self,
pipeline_config: EasyRecConfig,
drop_feature_names: List[str],
) -> None:
"""Drop feature name for feature group."""
for feature_group in pipeline_config.model_config.feature_groups:
feature_names = feature_group.feature_names
reserved_features = []
for feature_name in feature_names:
if feature_name not in drop_feature_names:
reserved_features.append(feature_name)
else:
logger.info("feature group drop feature: %s" % feature_name)
feature_group.ClearField("feature_names")
feature_group.feature_names.extend(reserved_features)
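            # Prune dropped features from sequence groups as well; a sequence
            # group left empty is removed together with its sequence encoders.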
del_sequence_groups = []
for sequence_group in list(feature_group.sequence_groups):
reserved_features = []
for feature_name in sequence_group.feature_names:
if feature_name not in drop_feature_names:
reserved_features.append(feature_name)
else:
logger.info("sequence group drop feature: %s" % feature_name)
sequence_group.ClearField("feature_names")
sequence_group.feature_names.extend(reserved_features)
if len(reserved_features) == 0:
del_sequence_groups.append(sequence_group.group_name)
feature_group.sequence_groups.remove(sequence_group)
logger.info("drop sequence group: %s" % sequence_group.group_name)
for seq_encoded in list(feature_group.sequence_encoders):
seq_module = getattr(seq_encoded, seq_encoded.WhichOneof("seq_module"))
if seq_module.input in del_sequence_groups:
feature_group.sequence_encoders.remove(seq_encoded)
logger.info("drop sequence encoder: %s" % seq_module.input)
def build(self) -> None:
"""Build method."""
feature_info_map, drop_feature_names = self._load_feature_info()
pipeline_config = config_util.load_pipeline_config(
self.template_model_config_path
)
self._drop_feature_config(pipeline_config, drop_feature_names)
self._update_feature_config(pipeline_config, feature_info_map)
self._update_feature_group(pipeline_config, drop_feature_names)
        config_util.save_message(pipeline_config, self.model_config_path)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--template_model_config_path",
type=str,
default=None,
help="template model config path",
)
parser.add_argument(
"--model_config_path", type=str, default=None, help="new model config path"
)
parser.add_argument(
"--config_table_path", type=str, default=None, help="feature config info path"
)
parser.add_argument(
"--reader_type",
type=str,
default="OdpsReader",
choices=["OdpsReader", "CsvReader", "ParquetReader"],
help="input path reader type.",
)
parser.add_argument(
"--odps_data_quota_name",
type=str,
default="pay-as-you-go",
help="maxcompute storage api/tunnel data quota name.",
)
args, extra_args = parser.parse_known_args()
fs = AddFeatureInfoToConfig(
args.template_model_config_path,
args.model_config_path,
args.config_table_path,
args.reader_type,
args.odps_data_quota_name,
)
fs.build()