easy_rec/python/inference/parquet_predictor_v2.py (123 lines of code) (raw):

# -*- encoding:utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. from __future__ import absolute_import from __future__ import division from __future__ import print_function import logging import os import numpy as np import pandas as pd import tensorflow as tf from tensorflow.python.platform import gfile from easy_rec.python.inference.predictor import Predictor from easy_rec.python.input.parquet_input_v2 import ParquetInputV2 from easy_rec.python.protos.dataset_pb2 import DatasetConfig from easy_rec.python.utils import config_util from easy_rec.python.utils import input_utils try: from tensorflow.python.framework.load_library import load_op_library import easy_rec load_embed_lib_path = os.path.join(easy_rec.ops_dir, 'libload_embed.so') load_embed_lib = load_op_library(load_embed_lib_path) except Exception as ex: logging.warning('load libload_embed.so failed: %s' % str(ex)) class ParquetPredictorV2(Predictor): def __init__(self, model_path, data_config, ds_vector_recall=False, fg_json_path=None, profiling_file=None, selected_cols=None, output_sep=chr(1), pipeline_config=None): super(ParquetPredictorV2, self).__init__(model_path, profiling_file, fg_json_path) self._output_sep = output_sep self._ds_vector_recall = ds_vector_recall input_type = DatasetConfig.InputType.Name(data_config.input_type).lower() self.pipeline_config = pipeline_config if 'rtp' in input_type: self._is_rtp = True self._input_sep = data_config.rtp_separator else: self._is_rtp = False self._input_sep = data_config.separator if selected_cols and not ds_vector_recall: self._selected_cols = [int(x) for x in selected_cols.split(',')] elif ds_vector_recall: self._selected_cols = selected_cols.split(',') else: self._selected_cols = None def _parse_line(self, line): out_dict = {} for key in line['feature']: out_dict[key] = line['feature'][key] if 'reserve' in line: out_dict['reserve'] = line['reserve'] # for key in line['reserve']: # out_dict[key] = line['reserve'][key] return out_dict def _get_reserved_cols(self, reserved_cols): # already parsed in _get_dataset return self._reserved_cols def _get_dataset(self, input_path, num_parallel_calls, batch_size, slice_num, slice_id): feature_configs = config_util.get_compatible_feature_configs( self.pipeline_config) kwargs = {} if self._reserved_args is not None and len(self._reserved_args) > 0: if self._reserved_args == 'ALL_COLUMNS': parquet_file = gfile.Glob(input_path.split(',')[0])[0] # gfile not supported, read_parquet requires random access all_data = pd.read_parquet(parquet_file) all_cols = list(all_data.columns) kwargs['reserve_fields'] = all_cols self._all_fields = all_cols self._reserved_cols = all_cols kwargs['reserve_types'] = input_utils.get_tf_type_from_parquet_file( all_cols, parquet_file) else: self._reserved_cols = [ x.strip() for x in self._reserved_args.split(',') if x.strip() != '' ] kwargs['reserve_fields'] = self._reserved_cols parquet_file = gfile.Glob(input_path.split(',')[0])[0] kwargs['reserve_types'] = input_utils.get_tf_type_from_parquet_file( self._reserved_cols, parquet_file) logging.info('reserve_fields=%s reserve_types=%s' % (','.join(self._reserved_cols), ','.join( [str(x) for x in kwargs['reserve_types']]))) else: self._reserved_cols = [] self.pipeline_config.data_config.batch_size = batch_size kwargs['is_predictor'] = True parquet_input = ParquetInputV2( self.pipeline_config.data_config, feature_configs, input_path, task_index=slice_id, task_num=slice_num, pipeline_config=self.pipeline_config, **kwargs) return parquet_input._build(tf.estimator.ModeKeys.PREDICT, {}) def _get_writer(self, output_path, slice_id): if not gfile.Exists(output_path): gfile.MakeDirs(output_path) res_path = os.path.join(output_path, 'part-%d.csv' % slice_id) table_writer = gfile.GFile(res_path, 'w') table_writer.write( self._output_sep.join(self._output_cols + self._reserved_cols) + '\n') return table_writer def _write_lines(self, table_writer, outputs): outputs = '\n'.join( [self._output_sep.join([str(i) for i in output]) for output in outputs]) table_writer.write(outputs + '\n') def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs): reserve_vals = [] for x in outputs: tmp_val = outputs[x] reserve_vals.append(tmp_val) for k in reserved_cols: tmp_val = all_vals['reserve'][k] if tmp_val.dtype == np.object: tmp_val = [x.decode('utf-8') for x in tmp_val] reserve_vals.append(tmp_val) return reserve_vals @property def out_of_range_exception(self): return (tf.errors.OutOfRangeError)