easy_rec/python/tools/pre_check.py:
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
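"""Pre-check tool for EasyRec pipeline configs and input data.

Validates the environment and input paths, then iterates over the input once
to run the sequence-feature checks before a real training run.

Example invocation (the paths below are illustrative placeholders, not part of
the original file):

  python -m easy_rec.python.tools.pre_check \
    --pipeline_config_path samples/model_config/din.config \
    --data_input_path data/train.csv \
    --data_input_path data/eval.csv
"""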
import json
import logging
import os
import sys

import tensorflow as tf

from easy_rec.python.input.input import Input
from easy_rec.python.utils import config_util
from easy_rec.python.utils import fg_util
from easy_rec.python.utils import io_util
from easy_rec.python.utils.check_utils import check_env_and_input_path
from easy_rec.python.utils.check_utils import check_sequence

if tf.__version__ >= '2.0':
  tf = tf.compat.v1

logging.basicConfig(
    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
    level=logging.INFO)

tf.app.flags.DEFINE_string('pipeline_config_path', None,
                           'Path to pipeline config file.')
tf.app.flags.DEFINE_multi_string(
    'data_input_path', None, help='data input path')
FLAGS = tf.app.flags.FLAGS


def _get_input_fn(data_config,
                  feature_configs,
                  data_path=None,
                  export_config=None):
  """Build estimator input function.

  Args:
    data_config: dataset config
    feature_configs: FeatureConfig
    data_path: input_data_path
    export_config: configuration for exporting models,
      only used to build input_fn when exporting models

  Returns:
    subclass of Input
  """
  input_class_map = {y: x for x, y in data_config.InputType.items()}
  input_cls_name = input_class_map[data_config.input_type]
  input_class = Input.create_class(input_cls_name)
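
  # Illustrative TF_CONFIG layout assumed by the parsing below (host names and
  # the task index are made-up examples, not values from the original file):
  #   {"cluster": {"worker": ["host0:2222", "host1:2222"]},
  #    "task": {"type": "worker", "index": 0}}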
  if 'TF_CONFIG' in os.environ:
    tf_config = json.loads(os.environ['TF_CONFIG'])
    worker_num = len(tf_config['cluster']['worker'])
    task_index = tf_config['task']['index']
  else:
    worker_num = 1
    task_index = 0

  input_obj = input_class(
      data_config,
      feature_configs,
      data_path,
      task_index=task_index,
      task_num=worker_num,
      check_mode=True)
  input_fn = input_obj.create_input(export_config)
  return input_fn


def load_pipeline_config(pipeline_config_path):
  """Load the pipeline config and expand fg/shared feature configs."""
  pipeline_config = config_util.get_configs_from_pipeline_file(
      pipeline_config_path, False)
  if pipeline_config.fg_json_path:
    fg_util.load_fg_json_to_config(pipeline_config)
  config_util.auto_expand_share_feature_configs(pipeline_config)
  return pipeline_config


def run_check(pipeline_config, input_path):
  """Iterate over the input data once and run the sequence-feature checks."""
  logging.info('data_input_path: %s' % input_path)
  check_env_and_input_path(pipeline_config, input_path)
  feature_configs = config_util.get_compatible_feature_configs(pipeline_config)
  eval_input_fn = _get_input_fn(pipeline_config.data_config, feature_configs,
                                input_path)
  eval_spec = tf.estimator.EvalSpec(
      name='val',
      input_fn=eval_input_fn,
      steps=None,
      throttle_secs=10,
      exporters=[])
  input_iter = eval_spec.input_fn(
      mode=tf.estimator.ModeKeys.EVAL).make_one_shot_iterator()
  # Build the get_next op once; creating it inside the loop would keep adding
  # nodes to the graph on every iteration.
  input_feas, input_lbls = input_iter.get_next()
  with tf.Session() as sess:
    try:
      while True:
        features = sess.run(input_feas)
        check_sequence(pipeline_config, features)
    except tf.errors.OutOfRangeError:
      logging.info('pre-check finish...')


def main(argv):
  assert FLAGS.pipeline_config_path, \
      'pipeline_config_path should not be empty when checking!'
  pipeline_config = load_pipeline_config(FLAGS.pipeline_config_path)
  if FLAGS.data_input_path:
    input_path = ','.join(FLAGS.data_input_path)
  else:
    assert pipeline_config.train_input_path or pipeline_config.eval_input_path, \
        'input_path should not be empty when checking!'
    input_path = pipeline_config.train_input_path + ',' + pipeline_config.eval_input_path
  run_check(pipeline_config, input_path)


if __name__ == '__main__':
  sys.argv = io_util.filter_unknown_args(FLAGS, sys.argv)
  tf.app.run()