easy_rec/python/predict.py
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
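"""Prediction entry point for EasyRec.

Predicts either with an exported saved_model (--saved_model_dir) or with a
checkpoint (--pipeline_config_path / --checkpoint_path).

Example invocation (a sketch; paths and columns are illustrative):
  python -m easy_rec.python.predict \
    --saved_model_dir=/path/to/saved_model \
    --input_path=/path/to/test.csv \
    --output_path=/path/to/output.csv \
    --output_cols='score float'
"""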
import json
import logging
import os
import tensorflow as tf
from tensorflow.python.lib.io import file_io
from easy_rec.python.inference.csv_predictor import CSVPredictor
from easy_rec.python.inference.hive_predictor import HivePredictor
from easy_rec.python.inference.parquet_predictor import ParquetPredictor
from easy_rec.python.inference.parquet_predictor_v2 import ParquetPredictorV2
from easy_rec.python.main import predict
from easy_rec.python.protos.dataset_pb2 import DatasetConfig
from easy_rec.python.utils import config_util
from easy_rec.python.utils import numpy_utils
from easy_rec.python.utils.hive_utils import HiveUtils
from easy_rec.python.inference.hive_parquet_predictor import HiveParquetPredictor  # NOQA

if tf.__version__ >= '2.0':
  # run under the TF1-compatible API when TF2 is installed
  tf = tf.compat.v1

logging.basicConfig(
    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
    level=logging.INFO)

tf.app.flags.DEFINE_string('input_path', None, 'predict data path')
tf.app.flags.DEFINE_string('output_path', None, 'path to save predict result')
tf.app.flags.DEFINE_integer('batch_size', 1024, help='batch size')

# predict by checkpoint
tf.app.flags.DEFINE_string('pipeline_config_path', None,
                           'Path to pipeline config file.')
tf.app.flags.DEFINE_string(
    'checkpoint_path', None,
    'checkpoint to be evaluated; if not specified, use the latest '
    'checkpoint in train_config.model_dir')
tf.app.flags.DEFINE_string(
    'model_dir', None,
    help='model directory; if pipeline.config exists under it, that config '
    'takes precedence over --pipeline_config_path')

# predict by saved_model
tf.app.flags.DEFINE_string('saved_model_dir', None,
                           help='directory of the exported saved_model')
tf.app.flags.DEFINE_string(
    'reserved_cols', 'ALL_COLUMNS',
    'columns to keep from the input table, separated by ,')
tf.app.flags.DEFINE_string(
    'output_cols', 'ALL_COLUMNS',
    'output columns, such as: score float; multiple columns are separated '
    'by ,')
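# the default separator chr(1) ('\x01', Ctrl-A) matches Hive's default
# field delimiter for text tables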
tf.app.flags.DEFINE_string('output_sep', chr(1),
                           'field separator of the predict result file')
tf.app.flags.DEFINE_string('selected_cols', None,
                           'columns used as model input, separated by ,')
tf.app.flags.DEFINE_string('fg_json_path', '',
                           'path to the fg.json feature generation config')
tf.app.flags.DEFINE_string('ds_vector_recall', '', '')
tf.app.flags.DEFINE_string('input_type', '',
                           'override data_config.input_type, e.g. CSVInput')

FLAGS = tf.app.flags.FLAGS

# input type name (e.g. 'CSVInput') -> enum value
input_class_map = dict(DatasetConfig.InputType.items())
# enum value -> input type name, used in error messages
input_class_map_r = {y: x for x, y in DatasetConfig.InputType.items()}


def get_input_type(input_type, data_config):
  """Return the input type enum value; --input_type overrides data_config."""
  if input_type:
    return input_class_map[input_type]
  return data_config.input_type
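
# e.g. get_input_type('CSVInput', data_config) -> DatasetConfig.CSVInput;
# get_input_type('', data_config) -> data_config.input_type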


def main(argv):
if FLAGS.saved_model_dir:
logging.info('Predict by saved_model.')
if FLAGS.pipeline_config_path:
pipeline_config_path = FLAGS.pipeline_config_path
else:
pipeline_config_path = config_util.search_pipeline_config(
FLAGS.saved_model_dir)
pipeline_config = config_util.get_configs_from_pipeline_file(
pipeline_config_path, False)
data_config = pipeline_config.data_config
input_type = get_input_type(FLAGS.input_type, data_config)
if input_type in [data_config.HiveParquetInput, data_config.HiveInput]:
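      # fetch column names and types from the input Hive table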
all_cols, all_col_types = HiveUtils(
data_config=pipeline_config.data_config,
hive_config=pipeline_config.hive_train_input).get_all_cols(
FLAGS.input_path)
if input_type == DatasetConfig.HiveParquetInput:
predictor = HiveParquetPredictor(
FLAGS.saved_model_dir,
pipeline_config.data_config,
fg_json_path=FLAGS.fg_json_path,
hive_config=pipeline_config.hive_train_input,
output_sep=FLAGS.output_sep,
all_cols=all_cols,
all_col_types=all_col_types)
else:
predictor = HivePredictor(
FLAGS.saved_model_dir,
pipeline_config.data_config,
fg_json_path=FLAGS.fg_json_path,
hive_config=pipeline_config.hive_train_input,
output_sep=FLAGS.output_sep,
all_cols=all_cols,
all_col_types=all_col_types)
elif input_type in [data_config.ParquetInput, data_config.ParquetInputV2]:
predictor_cls = ParquetPredictor
if input_type == data_config.ParquetInputV2:
predictor_cls = ParquetPredictorV2
predictor = predictor_cls(
FLAGS.saved_model_dir,
pipeline_config.data_config,
ds_vector_recall=FLAGS.ds_vector_recall,
fg_json_path=FLAGS.fg_json_path,
selected_cols=FLAGS.selected_cols,
output_sep=FLAGS.output_sep,
pipeline_config=pipeline_config)
elif input_type == data_config.CSVInput:
predictor = CSVPredictor(
FLAGS.saved_model_dir,
pipeline_config.data_config,
ds_vector_recall=FLAGS.ds_vector_recall,
fg_json_path=FLAGS.fg_json_path,
selected_cols=FLAGS.selected_cols,
output_sep=FLAGS.output_sep)
else:
assert False, 'invalid input type: %s' % input_class_map_r[input_type]
logging.info('input_path = %s, output_path = %s' %
(FLAGS.input_path, FLAGS.output_path))
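    # In distributed mode each worker predicts one shard of the input.
    # TF_CONFIG is the standard TensorFlow cluster spec, e.g. (illustrative):
    #   {"cluster": {"worker": ["host1:2222", "host2:2222"]},
    #    "task": {"type": "worker", "index": 0}}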
if 'TF_CONFIG' in os.environ:
tf_config = json.loads(os.environ['TF_CONFIG'])
worker_num = len(tf_config['cluster']['worker'])
task_index = tf_config['task']['index']
else:
worker_num = 1
task_index = 0
predictor.predict_impl(
FLAGS.input_path,
FLAGS.output_path,
reserved_cols=FLAGS.reserved_cols,
output_cols=FLAGS.output_cols,
batch_size=FLAGS.batch_size,
slice_id=task_index,
slice_num=worker_num)
else:
logging.info('Predict by checkpoint_path.')
    assert FLAGS.model_dir or FLAGS.pipeline_config_path, \
        'at least one of --model_dir and --pipeline_config_path must be specified.'
    if FLAGS.model_dir:
      # prefer pipeline.config under model_dir when it exists
      pipeline_config_path = os.path.join(FLAGS.model_dir, 'pipeline.config')
      if file_io.file_exists(pipeline_config_path):
        logging.info('update pipeline_config_path to %s' % pipeline_config_path)
else:
pipeline_config_path = FLAGS.pipeline_config_path
else:
pipeline_config_path = FLAGS.pipeline_config_path
pred_result = predict(pipeline_config_path, FLAGS.checkpoint_path,
FLAGS.input_path)
if FLAGS.output_path is not None:
logging.info('will save predict result to %s' % FLAGS.output_path)
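      # write one JSON object per line (JSONL); NumpyEncoder converts numpy
      # values to plain Python types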
with tf.gfile.GFile(FLAGS.output_path, 'wb') as fout:
for k in pred_result:
fout.write(json.dumps(k, cls=numpy_utils.NumpyEncoder) + '\n')


if __name__ == '__main__':
  tf.app.run()