easy_rec/python/inference/hive_predictor.py:

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

import tensorflow as tf
from tensorflow.python.platform import gfile

from easy_rec.python.inference.predictor import Predictor
from easy_rec.python.protos.dataset_pb2 import DatasetConfig
from easy_rec.python.utils import tf_utils
from easy_rec.python.utils.hive_utils import HiveUtils

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


class HivePredictor(Predictor):
  """Predictor that reads input from a Hive table and writes results back to Hive."""

  def __init__(self,
               model_path,
               data_config,
               hive_config,
               fg_json_path=None,
               profiling_file=None,
               output_sep=chr(1),
               all_cols=None,
               all_col_types=None):
    super(HivePredictor, self).__init__(model_path, profiling_file,
                                        fg_json_path)

    self._data_config = data_config
    self._hive_config = hive_config
    self._output_sep = output_sep
    input_type = DatasetConfig.InputType.Name(data_config.input_type).lower()
    if 'rtp' in input_type:
      self._is_rtp = True
    else:
      self._is_rtp = False
    self._all_cols = [x.strip() for x in all_cols if x != '']
    self._all_col_types = [x.strip() for x in all_col_types if x != '']
    self._record_defaults = [
        self._get_defaults(col_name, col_type)
        for col_name, col_type in zip(self._all_cols, self._all_col_types)
    ]

  def _get_reserved_cols(self, reserved_cols):
    if reserved_cols == 'ALL_COLUMNS':
      reserved_cols = self._all_cols
    else:
      reserved_cols = [x.strip() for x in reserved_cols.split(',') if x != '']
    return reserved_cols

  def _parse_line(self, line):
    field_delim = self._data_config.rtp_separator if self._is_rtp \
        else self._data_config.separator
    fields = tf.decode_csv(
        line,
        field_delim=field_delim,
        record_defaults=self._record_defaults,
        name='decode_csv')
    inputs = {self._all_cols[x]: fields[x] for x in range(len(fields))}
    return inputs

  def _get_dataset(self, input_path, num_parallel_calls, batch_size, slice_num,
                   slice_id):
    # Resolve the table's HDFS location and read its files as text lines,
    # sharded across workers.
    self._hive_util = HiveUtils(
        data_config=self._data_config, hive_config=self._hive_config)
    self._input_hdfs_path = self._hive_util.get_table_location(input_path)
    file_paths = tf.gfile.Glob(os.path.join(self._input_hdfs_path, '*'))
    assert len(file_paths) > 0, 'no files matched under %s' % input_path

    dataset = tf.data.Dataset.from_tensor_slices(file_paths)
    parallel_num = min(num_parallel_calls, len(file_paths))
    dataset = dataset.interleave(
        tf.data.TextLineDataset,
        cycle_length=parallel_num,
        num_parallel_calls=parallel_num)
    dataset = dataset.shard(slice_num, slice_id)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=64)
    return dataset

  def get_table_info(self, output_path):
    # An output path of the form 'table/partition=value' is split into
    # (table, partition name, partition value); a bare table name keeps
    # the partition fields as None.
    partition_name, partition_val = None, None
    if len(output_path.split('/')) == 2:
      table_name, partition = output_path.split('/')
      partition_name, partition_val = partition.split('=')
    else:
      table_name = output_path
    return table_name, partition_name, partition_val

  def _get_writer(self, output_path, slice_id):
    table_name, partition_name, partition_val = self.get_table_info(
        output_path)
    is_exist = self._hive_util.is_table_or_partition_exist(
        table_name, partition_name, partition_val)
    assert not is_exist, '%s already exists. Please drop it.' % output_path

    # Results are staged in a temporary HDFS directory and loaded into Hive
    # later by load_to_table.
    output_path = output_path.replace('.', '/')
    self._hdfs_path = 'hdfs://%s:9000/user/easy_rec/%s_tmp' % (
        self._hive_config.host, output_path)
    if not gfile.Exists(self._hdfs_path):
      gfile.MakeDirs(self._hdfs_path)
    res_path = os.path.join(self._hdfs_path, 'part-%d.csv' % slice_id)
    table_writer = gfile.GFile(res_path, 'w')
    return table_writer

  def _write_lines(self, table_writer, outputs):
    outputs = '\n'.join([
        self._output_sep.join([str(i) for i in output]) for output in outputs
    ])
    table_writer.write(outputs + '\n')

  def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs):
    reserve_vals = [outputs[x] for x in output_cols] + \
                   [all_vals[k] for k in reserved_cols]
    return reserve_vals

  def load_to_table(self, output_path, slice_num, slice_id):
    # Each slice drops a SUCCESS marker; slice 0 waits for all of them
    # before creating the table and loading the staged files.
    res_path = os.path.join(self._hdfs_path, 'SUCCESS-%s' % slice_id)
    success_writer = gfile.GFile(res_path, 'w')
    success_writer.write('')
    success_writer.close()

    if slice_id != 0:
      return

    for slice_idx in range(slice_num):
      res_path = os.path.join(self._hdfs_path, 'SUCCESS-%s' % slice_idx)
      while not gfile.Exists(res_path):
        time.sleep(10)

    table_name, partition_name, partition_val = self.get_table_info(
        output_path)

    # Build the Hive schema: model output columns first, then the reserved
    # input columns with their original types.
    schema = ''
    for output_col_name in self._output_cols:
      tf_type = self._predictor_impl._outputs_map[output_col_name].dtype
      col_type = tf_utils.get_col_type(tf_type)
      schema += output_col_name + ' ' + col_type + ','

    for output_col_name in self._reserved_cols:
      assert output_col_name in self._all_cols, \
          'Column: %s does not exist.' % output_col_name
      idx = self._all_cols.index(output_col_name)
      output_col_types = self._all_col_types[idx]
      schema += output_col_name + ' ' + output_col_types + ','
    schema = schema.rstrip(',')

    if partition_name and partition_val:
      sql = 'create table if not exists %s (%s) PARTITIONED BY (%s string)' % \
            (table_name, schema, partition_name)
      self._hive_util.run_sql(sql)
      sql = "LOAD DATA INPATH '%s/*' INTO TABLE %s PARTITION (%s=%s)" % \
            (self._hdfs_path, table_name, partition_name, partition_val)
      self._hive_util.run_sql(sql)
    else:
      sql = 'create table if not exists %s (%s)' % \
            (table_name, schema)
      self._hive_util.run_sql(sql)
      sql = "LOAD DATA INPATH '%s/*' INTO TABLE %s" % \
            (self._hdfs_path, table_name)
      self._hive_util.run_sql(sql)

  @property
  def out_of_range_exception(self):
    return tf.errors.OutOfRangeError
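
Below is a minimal usage sketch, not part of the file above. Constructing HivePredictor needs an exported model directory plus the DatasetConfig and Hive connection config from the pipeline config, so the snippet only exercises get_table_info, which parses output paths of the form 'table/partition=value' and touches no instance state. The table and partition names are hypothetical examples.

# usage_sketch.py -- illustrative only; table/partition names are made up
from easy_rec.python.inference.hive_predictor import HivePredictor

# get_table_info does not use self, so under Python 3 it can be called
# unbound with a dummy first argument.
table_name, part_name, part_val = HivePredictor.get_table_info(
    None, 'ctr_result/dt=20240101')
print(table_name, part_name, part_val)  # -> ctr_result dt 20240101

table_name, part_name, part_val = HivePredictor.get_table_info(
    None, 'ctr_result')
print(table_name, part_name, part_val)  # -> ctr_result None None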