easy_rec/python/tools/add_boundaries_to_config.py (52 lines of code) (raw):

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Add quantile-binning boundaries from a table to a pipeline config.

Reads (feature, json) rows from ``--tables``, extracts the bin split points
for each feature, and rewrites every matching feature config of
``--template_config_path`` as a RawFeature with those boundaries, saving the
result to ``--output_config_path``.
"""
import json
import logging
import os
import sys

import common_io
import tensorflow as tf

from easy_rec.python.utils import config_util
from easy_rec.python.utils import io_util

if tf.__version__ >= '2.0':
  tf = tf.compat.v1

logging.basicConfig(
    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
    level=logging.INFO)

tf.app.flags.DEFINE_string('template_config_path', None,
                           'Path to template pipeline config '
                           'file.')
tf.app.flags.DEFINE_string('output_config_path', None,
                           'Path to output pipeline config '
                           'file.')
tf.app.flags.DEFINE_string('tables', '', 'quantile binning table')
FLAGS = tf.app.flags.FLAGS


def _parse_boundaries(raw_json):
  """Return the bin split points encoded in one binning-table JSON string.

  Args:
    raw_json: JSON text containing a ``bin.norm`` list of dicts, each with a
      string ``value``. Presumably each value is an interval such as
      ``"(lo,hi]"`` -- TODO confirm against the binning job's output format.

  Returns:
    A list of floats: for every bin except the last, the second
    comma-separated field of ``value`` with its final character dropped.
  """
  bins = json.loads(raw_json)['bin']['norm']
  # The last bin is skipped; presumably its upper edge is open/unbounded.
  return [float(b['value'].split(',')[1][:-1]) for b in bins[:-1]]


def _read_feature_boundaries(tables):
  """Read ``{feature_name: [split points]}`` from the binning table(s).

  Args:
    tables: table spec passed to ``common_io.table.TableReader``; only the
      ``feature`` and ``json`` columns are read.

  Returns:
    Dict mapping each feature name to the list from ``_parse_boundaries``.
    A later row for the same feature overwrites an earlier one.
  """
  feature_boundaries = {}
  reader = common_io.table.TableReader(tables, selected_cols='feature,json')
  try:
    while True:
      try:
        record = reader.read()
      except common_io.exception.OutOfRangeException:
        # End of table reached.
        break
      feature_name, raw_json = record[0][0], record[0][1]
      feature_boundaries[feature_name] = _parse_boundaries(raw_json)
  finally:
    # Close the reader even when a row fails to parse.
    reader.close()
  return feature_boundaries


def main(argv):
  """Write ``--output_config_path`` with boundaries filled in from ``--tables``.

  Every feature config whose first input name appears in the binning table is
  switched to ``RawFeature`` with ``hash_bucket_size = 0``, and its
  ``boundaries`` are replaced by the table's split points. Feature configs
  without input names, or not present in the table, are left untouched.

  Args:
    argv: unused; required by ``tf.app.run``.
  """
  pipeline_config = config_util.get_configs_from_pipeline_file(
      FLAGS.template_config_path)

  feature_boundaries_info = _read_feature_boundaries(FLAGS.tables)
  logging.info('feature boundaries: %s', feature_boundaries_info)

  for feature_config in pipeline_config.feature_configs:
    if not feature_config.input_names:
      continue
    feature_name = feature_config.input_names[0]
    if feature_name not in feature_boundaries_info:
      continue
    feature_config.feature_type = feature_config.RawFeature
    feature_config.hash_bucket_size = 0
    # Replace rather than append: extending onto boundaries already present
    # in the template would produce a duplicated, non-monotonic list.
    del feature_config.boundaries[:]
    feature_config.boundaries.extend(feature_boundaries_info[feature_name])
    logging.info('edited %s', feature_name)

  config_dir, config_name = os.path.split(FLAGS.output_config_path)
  config_util.save_pipeline_config(pipeline_config, config_dir, config_name)


if __name__ == '__main__':
  sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
  tf.app.run()