easy_rec/python/tools/add_feature_info_to_config.py (126 lines of code) (raw):

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Apply per-feature settings from a config table to a template pipeline config.

The config table is read with the columns ``feature``, ``feature_info`` and
``message``:

* ``feature_info`` is a JSON object. For ordinary rows it carries
  ``embedding_dim`` and optionally ``boundary`` or ``hash_bucket_size``.
* Rows named ``__NUM_STEPS__`` and ``__DECAY_STEPS__`` carry the training
  ``num_steps`` and the learning-rate ``decay_steps``.
* A ``message`` containing ``DROP IT`` marks the feature for removal from the
  feature configs, the feature groups and the sequence attention maps.

The edited config is written to ``--output_config_path``.
"""
import json
import logging
import os
import sys

import tensorflow as tf

from easy_rec.python.utils import config_util
from easy_rec.python.utils import io_util
from easy_rec.python.utils.hive_utils import HiveUtils

# Use the TF1-compatible API surface when running under TensorFlow 2.x.
if tf.__version__ >= '2.0':
    tf = tf.compat.v1

logging.basicConfig(
    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
    level=logging.INFO)

tf.app.flags.DEFINE_string('template_config_path', None,
                           'Path to template pipeline config '
                           'file.')
tf.app.flags.DEFINE_string('output_config_path', None,
                           'Path to output pipeline config '
                           'file.')
tf.app.flags.DEFINE_string('config_table', '', 'config table')

FLAGS = tf.app.flags.FLAGS

# Columns selected from the config table, in the order the rows are indexed.
_SELECTED_COLS = 'feature,feature_info,message'
# Substring of the ``message`` column that marks a feature for removal.
_DROP_MARKER = 'DROP IT'


def _record_feature_info(row, feature_info_map, drop_feature_names):
    """Store one ``(feature, feature_info, message)`` row.

    Args:
      row: indexable with the feature name, the JSON-encoded feature info
        and the message at positions 0, 1 and 2.
      feature_info_map: dict updated in place with the parsed feature info.
      drop_feature_names: set updated in place with features marked for drop.
    """
    feature_name = row[0]
    feature_info_map[feature_name] = json.loads(row[1])
    if _DROP_MARKER in row[2]:
        drop_feature_names.add(feature_name)


def _load_feature_info(pipeline_config, config_table):
    """Read the config table into a feature-info map and a drop set.

    The table is read through ``HiveUtils`` when the pipeline's ``train_path``
    oneof is ``hive_train_input``; otherwise through ``common_io``.

    Args:
      pipeline_config: the parsed template pipeline config.
      config_table: name/path of the table to read.

    Returns:
      A ``(feature_info_map, drop_feature_names)`` tuple: a dict mapping the
      feature name to its parsed JSON info, and the set of feature names whose
      message contains ``DROP IT``.
    """
    feature_info_map = {}
    drop_feature_names = set()

    if pipeline_config.WhichOneof('train_path') == 'hive_train_input':
        hive_util = HiveUtils(
            data_config=pipeline_config.data_config,
            hive_config=pipeline_config.hive_train_input,
            selected_cols=_SELECTED_COLS,
            record_defaults=['', '', ''])
        for record in hive_util.hive_read_line(config_table):
            # The row is the first element of each yielded record.
            _record_feature_info(record[0], feature_info_map,
                                 drop_feature_names)
    else:
        # Imported lazily: only needed (and possibly only installed) when
        # the table is not read from hive.
        import common_io
        reader = common_io.table.TableReader(
            config_table, selected_cols=_SELECTED_COLS)
        try:
            while True:
                try:
                    record = reader.read()
                except common_io.exception.OutOfRangeException:
                    # End of table.
                    break
                _record_feature_info(record[0], feature_info_map,
                                     drop_feature_names)
        finally:
            # Close the reader even if a row fails to parse.
            reader.close()

    return feature_info_map, drop_feature_names


def _apply_feature_info(pipeline_config, feature_info_map,
                        drop_feature_names):
    """Remove dropped feature configs and update the remaining ones.

    Features are matched by their first input name. For a matched feature,
    ``embedding_dim`` is always set; ``boundaries`` is replaced when the info
    has ``boundary``, otherwise ``hash_bucket_size`` is set when present.

    Raises:
      KeyError: if a matched feature info has no ``embedding_dim``.
    """
    feature_configs = config_util.get_compatible_feature_configs(
        pipeline_config)

    if drop_feature_names:
        # Iterate over a snapshot because removal mutates the container.
        for fea_cfg in feature_configs[:]:
            if fea_cfg.input_names[0] in drop_feature_names:
                feature_configs.remove(fea_cfg)

    for feature_config in feature_configs:
        feature_name = feature_config.input_names[0]
        if feature_name not in feature_info_map:
            continue
        feature_info = feature_info_map[feature_name]

        logging.info('edited %s', feature_name)
        feature_config.embedding_dim = int(feature_info['embedding_dim'])
        logging.info('modify embedding_dim to %s',
                     feature_config.embedding_dim)

        if 'boundary' in feature_info:
            feature_config.ClearField('boundaries')
            feature_config.boundaries.extend(
                [float(i) for i in feature_info['boundary']])
            logging.info('modify boundaries to %s', feature_config.boundaries)
        elif 'hash_bucket_size' in feature_info:
            feature_config.hash_bucket_size = int(
                feature_info['hash_bucket_size'])
            logging.info('modify hash_bucket_size to %s',
                         feature_config.hash_bucket_size)


def _update_train_steps(pipeline_config, feature_info_map):
    """Set ``num_steps`` and every learning rate's ``decay_steps``.

    Raises:
      KeyError: if ``__NUM_STEPS__`` is missing, or if ``__DECAY_STEPS__`` is
        missing while some learning rate has a ``decay_steps`` field.
    """
    # Cast like the other numeric settings so that string or float JSON
    # values are accepted by the integer protobuf field.
    pipeline_config.train_config.num_steps = int(
        feature_info_map['__NUM_STEPS__']['num_steps'])
    logging.info('modify num_steps to %s',
                 pipeline_config.train_config.num_steps)

    for optimizer_config in pipeline_config.train_config.optimizer_config:
        optimizer_name = optimizer_config.WhichOneof('optimizer')
        if optimizer_name is None:
            # Oneof not set: there is no optimizer message to edit.
            continue
        optimizer = getattr(optimizer_config, optimizer_name)

        learning_rate_name = optimizer.learning_rate.WhichOneof(
            'learning_rate')
        if learning_rate_name is None:
            # Oneof not set: there is no learning-rate message to edit.
            continue
        learning_rate = getattr(optimizer.learning_rate, learning_rate_name)

        if hasattr(learning_rate, 'decay_steps'):
            learning_rate.decay_steps = int(
                feature_info_map['__DECAY_STEPS__']['decay_steps'])
            logging.info('modify decay_steps to %s',
                         learning_rate.decay_steps)


def _prune_repeated_field(message, field_name, drop_feature_names,
                          log_template):
    """Rewrite a repeated field of ``message`` without the dropped names.

    Args:
      message: protobuf message owning the repeated field.
      field_name: name of the repeated field to filter.
      drop_feature_names: names to remove from the field.
      log_template: ``%s``-style message logged for every removed name.
    """
    kept = []
    for value in getattr(message, field_name):
        if value in drop_feature_names:
            logging.info(log_template, value)
        else:
            kept.append(value)
    message.ClearField(field_name)
    getattr(message, field_name).extend(kept)


def _prune_feature_groups(pipeline_config, drop_feature_names):
    """Remove dropped names from feature groups and sequence attention maps."""
    for feature_group in pipeline_config.model_config.feature_groups:
        _prune_repeated_field(feature_group, 'feature_names',
                              drop_feature_names, 'drop feature: %s')
        for sequence_feature in feature_group.sequence_features:
            for seq_att in sequence_feature.seq_att_map:
                _prune_repeated_field(seq_att, 'key', drop_feature_names,
                                      'drop sequence feature key: %s')
                _prune_repeated_field(seq_att, 'hist_seq', drop_feature_names,
                                      'drop sequence feature hist_seq: %s')


def main(argv):
    """Build the output pipeline config from the template and config table.

    Reads ``--template_config_path``, applies the settings found in
    ``--config_table`` and saves the result to ``--output_config_path``.

    Args:
      argv: command line arguments passed by ``tf.app.run``; unused.
    """
    pipeline_config = config_util.get_configs_from_pipeline_file(
        FLAGS.template_config_path)

    feature_info_map, drop_feature_names = _load_feature_info(
        pipeline_config, FLAGS.config_table)

    _apply_feature_info(pipeline_config, feature_info_map,
                        drop_feature_names)
    _update_train_steps(pipeline_config, feature_info_map)
    _prune_feature_groups(pipeline_config, drop_feature_names)

    config_dir, config_name = os.path.split(FLAGS.output_config_path)
    config_util.save_pipeline_config(pipeline_config, config_dir, config_name)


if __name__ == '__main__':
    sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
    tf.app.run()