def main()

in easy_rec/python/predict.py


def main(argv):
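  # Two prediction modes: use the exported saved_model when
  # --saved_model_dir is set; otherwise restore from a checkpoint.
  # FLAGS, config_util, HiveUtils, DatasetConfig, get_input_type, predict,
  # numpy_utils and the *Predictor classes come from module scope in
  # easy_rec/python/predict.py.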

  if FLAGS.saved_model_dir:
    logging.info('Predict by saved_model.')
    if FLAGS.pipeline_config_path:
      pipeline_config_path = FLAGS.pipeline_config_path
    else:
      pipeline_config_path = config_util.search_pipeline_config(
          FLAGS.saved_model_dir)
    pipeline_config = config_util.get_configs_from_pipeline_file(
        pipeline_config_path, False)
    data_config = pipeline_config.data_config
    input_type = get_input_type(FLAGS.input_type, data_config)
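    # Hive-backed inputs first fetch the table's column names and types
    # from the metastore so the predictor knows the full schema.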
    if input_type in [data_config.HiveParquetInput, data_config.HiveInput]:
      all_cols, all_col_types = HiveUtils(
          data_config=pipeline_config.data_config,
          hive_config=pipeline_config.hive_train_input).get_all_cols(
              FLAGS.input_path)
      # Both Hive predictors take identical constructor arguments; pick the
      # class the same way the Parquet branch below does.
      predictor_cls = HivePredictor
      if input_type == data_config.HiveParquetInput:
        predictor_cls = HiveParquetPredictor
      predictor = predictor_cls(
          FLAGS.saved_model_dir,
          pipeline_config.data_config,
          fg_json_path=FLAGS.fg_json_path,
          hive_config=pipeline_config.hive_train_input,
          output_sep=FLAGS.output_sep,
          all_cols=all_cols,
          all_col_types=all_col_types)
    elif input_type in [data_config.ParquetInput, data_config.ParquetInputV2]:
      predictor_cls = ParquetPredictor
      if input_type == data_config.ParquetInputV2:
        predictor_cls = ParquetPredictorV2
      predictor = predictor_cls(
          FLAGS.saved_model_dir,
          pipeline_config.data_config,
          ds_vector_recall=FLAGS.ds_vector_recall,
          fg_json_path=FLAGS.fg_json_path,
          selected_cols=FLAGS.selected_cols,
          output_sep=FLAGS.output_sep,
          pipeline_config=pipeline_config)
    elif input_type == data_config.CSVInput:
      predictor = CSVPredictor(
          FLAGS.saved_model_dir,
          pipeline_config.data_config,
          ds_vector_recall=FLAGS.ds_vector_recall,
          fg_json_path=FLAGS.fg_json_path,
          selected_cols=FLAGS.selected_cols,
          output_sep=FLAGS.output_sep)
    else:
      assert False, 'invalid input type: %s' % input_class_map_r[input_type]

    logging.info('input_path = %s, output_path = %s' %
                 (FLAGS.input_path, FLAGS.output_path))
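    # In distributed runs, TF_CONFIG supplies the worker list and this
    # worker's index; each worker then predicts a disjoint slice of the
    # input (slice_id / slice_num below).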
    if 'TF_CONFIG' in os.environ:
      tf_config = json.loads(os.environ['TF_CONFIG'])
      worker_num = len(tf_config['cluster']['worker'])
      task_index = tf_config['task']['index']
    else:
      worker_num = 1
      task_index = 0
    predictor.predict_impl(
        FLAGS.input_path,
        FLAGS.output_path,
        reserved_cols=FLAGS.reserved_cols,
        output_cols=FLAGS.output_cols,
        batch_size=FLAGS.batch_size,
        slice_id=task_index,
        slice_num=worker_num)
  else:
    logging.info('Predict by checkpoint_path.')
    assert FLAGS.model_dir or FLAGS.pipeline_config_path, \
        'At least one of --model_dir and --pipeline_config_path must be set.'
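    # Prefer the pipeline.config snapshot saved alongside the model and
    # fall back to --pipeline_config_path if no snapshot exists.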
    if FLAGS.model_dir:
      pipeline_config_path = os.path.join(FLAGS.model_dir, 'pipeline.config')
      if file_io.file_exists(pipeline_config_path):
        logging.info('update pipeline_config_path to %s' % pipeline_config_path)
      else:
        pipeline_config_path = FLAGS.pipeline_config_path
    else:
      pipeline_config_path = FLAGS.pipeline_config_path

    pred_result = predict(pipeline_config_path, FLAGS.checkpoint_path,
                          FLAGS.input_path)
    if FLAGS.output_path is not None:
      logging.info('will save predict result to %s' % FLAGS.output_path)
      with tf.gfile.GFile(FLAGS.output_path, 'wb') as fout:
        for k in pred_result:
          fout.write(json.dumps(k, cls=numpy_utils.NumpyEncoder) + '\n')
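
For reference, a minimal sketch of driving both branches from the command
line, assuming the module runs via python -m; the flag names are taken from
the code above, while the paths, directories and column names are
illustrative assumptions:

  # saved_model mode (--saved_model_dir set)
  python -m easy_rec.python.predict \
      --saved_model_dir=export/1658899800 \
      --input_path=data/test.csv \
      --output_path=data/test_out.csv \
      --reserved_cols=user_id,item_id \
      --output_cols=probs

  # checkpoint mode (--model_dir and/or --pipeline_config_path set);
  # predictions are written as JSON lines when --output_path is given
  python -m easy_rec.python.predict \
      --model_dir=experiments/deepfm \
      --input_path=data/test.csv \
      --output_path=data/pred.json

  # distributed sharding: with this TF_CONFIG the process acts as worker 1
  # of 2 and handles the second slice of the input
  export TF_CONFIG='{"cluster": {"worker": ["h1:2222", "h2:2222"]},
                     "task": {"type": "worker", "index": 1}}'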