# easy_rec/python/inference/hive_parquet_predictor.py
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.python.platform import gfile

from easy_rec.python.inference.predictor import Predictor
from easy_rec.python.protos.dataset_pb2 import DatasetConfig
from easy_rec.python.utils import tf_utils
from easy_rec.python.utils.hive_utils import HiveUtils
from easy_rec.python.utils.tf_utils import get_tf_type

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


class HiveParquetPredictor(Predictor):
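  """Predictor that reads parquet files under a Hive table location and
  loads the prediction results back into a Hive table.
  """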

  def __init__(self,
               model_path,
               data_config,
               hive_config,
               fg_json_path=None,
               profiling_file=None,
               output_sep=chr(1),
               all_cols=None,
               all_col_types=None):
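    """Init HiveParquetPredictor.

    Args:
      model_path: path of the exported model to load.
      data_config: DatasetConfig proto describing the input fields.
      hive_config: hive connection config used to locate the input table and
        to write results back to hive.
      fg_json_path: optional fg json path, passed to the base Predictor.
      profiling_file: optional profiling output file, passed to the base
        Predictor.
      output_sep: column separator used when writing output lines.
      all_cols: names of all columns of the input table.
      all_col_types: hive types of all columns of the input table.
    """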
    super(HiveParquetPredictor, self).__init__(model_path, profiling_file,
                                               fg_json_path)

    self._data_config = data_config
    self._hive_config = hive_config
    self._output_sep = output_sep
    input_type = DatasetConfig.InputType.Name(data_config.input_type).lower()
    self._is_rtp = 'rtp' in input_type
    self._all_cols = [x.strip() for x in all_cols if x != '']
    self._all_col_types = [x.strip() for x in all_col_types if x != '']
    self._record_defaults = [
        self._get_defaults(col_name, col_type)
        for col_name, col_type in zip(self._all_cols, self._all_col_types)
    ]

  def _get_reserved_cols(self, reserved_cols):
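    """Parse reserved_cols; 'ALL_COLUMNS' selects every input column."""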
    if reserved_cols == 'ALL_COLUMNS':
      reserved_cols = self._all_cols
    else:
      reserved_cols = [x.strip() for x in reserved_cols.split(',') if x != '']
    return reserved_cols

  def _parse_line(self, *fields):
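    """Map positional dataset fields to a {column_name: value} dict."""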
    fields = list(fields)
    field_dict = {self._all_cols[i]: fields[i] for i in range(len(fields))}
    return field_dict

  def _get_dataset(self, input_path, num_parallel_calls, batch_size, slice_num,
                   slice_id):
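    """Build a tf.data.Dataset over the parquet files of the input Hive table.

    The dataset yields tuples of per-column batches and is sharded by
    (slice_num, slice_id) so that each worker reads a distinct part.
    """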
    self._hive_util = HiveUtils(
        data_config=self._data_config, hive_config=self._hive_config)
    hdfs_path = self._hive_util.get_table_location(input_path)
    self._input_hdfs_path = gfile.Glob(os.path.join(hdfs_path, '*'))
    assert len(self._input_hdfs_path) > 0, \
        'no input files matched for %s' % input_path
    list_type = []
    input_field_type_map = {
        x.input_name: x.input_type for x in self._data_config.input_fields
    }
    type_2_tftype = {
        'string': tf.string,
        'double': tf.double,
        'float': tf.float32,
        'bigint': tf.int32,
        'boolean': tf.bool
    }
    for col_name, col_type in zip(self._all_cols, self._all_col_types):
      if col_name in input_field_type_map:
        list_type.append(get_tf_type(input_field_type_map[col_name]))
      else:
        list_type.append(type_2_tftype[col_type.lower()])
    list_type = tuple(list_type)
    list_shapes = [tf.TensorShape([None]) for x in range(0, len(list_type))]
    list_shapes = tuple(list_shapes)

    def parquet_read():
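      """Yield one tuple of per-column numpy arrays for each batch of rows."""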
      for input_path in self._input_hdfs_path:
        if input_path.endswith('SUCCESS'):
          continue
        df = pd.read_parquet(input_path, engine='pyarrow')
        df.replace('', np.nan, inplace=True)
        df.replace('NULL', np.nan, inplace=True)
        total_records_num = len(df)
        for k, v in zip(self._all_cols, self._record_defaults):
          df[k].fillna(v, inplace=True)
        for start_idx in range(0, total_records_num, batch_size):
          end_idx = min(total_records_num, start_idx + batch_size)
          batch_data = df[start_idx:end_idx]
          inputs = []
          for k in self._all_cols:
            inputs.append(batch_data[k].to_numpy())
          yield tuple(inputs)
    dataset = tf.data.Dataset.from_generator(
        parquet_read, output_types=list_type, output_shapes=list_shapes)
    dataset = dataset.shard(slice_num, slice_id)
    dataset = dataset.prefetch(buffer_size=64)
    return dataset

  def get_table_info(self, output_path):
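    """Split 'table' or 'table/partition=val' into table and partition parts."""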
    partition_name, partition_val = None, None
    if len(output_path.split('/')) == 2:
      table_name, partition = output_path.split('/')
      partition_name, partition_val = partition.split('=')
    else:
      table_name = output_path
    return table_name, partition_name, partition_val

  def _get_writer(self, output_path, slice_id):
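    """Open this slice's csv part file under a temporary HDFS directory."""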
    table_name, partition_name, partition_val = self.get_table_info(
        output_path)
    is_exist = self._hive_util.is_table_or_partition_exist(
        table_name, partition_name, partition_val)
    assert not is_exist, '%s already exists. Please drop it.' % output_path
    output_path = output_path.replace('.', '/')
    self._hdfs_path = 'hdfs://%s:9000/user/easy_rec/%s_tmp' % (
        self._hive_config.host, output_path)
    if not gfile.Exists(self._hdfs_path):
      gfile.MakeDirs(self._hdfs_path)
    res_path = os.path.join(self._hdfs_path, 'part-%d.csv' % slice_id)
    table_writer = gfile.GFile(res_path, 'w')
    return table_writer

  def _write_lines(self, table_writer, outputs):
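    """Write one line per output row, columns joined by output_sep."""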
    outputs = '\n'.join(
        [self._output_sep.join([str(i) for i in output]) for output in outputs])
    table_writer.write(outputs + '\n')

  def _get_reserve_vals(self, reserved_cols, output_cols, all_vals, outputs):
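    """Concatenate model outputs with reserved input columns for one batch."""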
    reserve_vals = [outputs[x] for x in output_cols] + \
                   [all_vals[k] for k in reserved_cols]
    return reserve_vals

  def load_to_table(self, output_path, slice_num, slice_id):
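    """Create the output Hive table and load the part files into it.

    Every slice writes a SUCCESS flag first; slice 0 then waits for all
    flags before issuing the CREATE TABLE / LOAD DATA statements.
    """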
    res_path = os.path.join(self._hdfs_path, 'SUCCESS-%s' % slice_id)
    success_writer = gfile.GFile(res_path, 'w')
    success_writer.write('')
    success_writer.close()

    if slice_id != 0:
      return

    for id in range(slice_num):
      res_path = os.path.join(self._hdfs_path, 'SUCCESS-%s' % id)
      while not gfile.Exists(res_path):
        time.sleep(10)

    table_name, partition_name, partition_val = self.get_table_info(
        output_path)
    schema = ''
    for output_col_name in self._output_cols:
      tf_type = self._predictor_impl._outputs_map[output_col_name].dtype
      col_type = tf_utils.get_col_type(tf_type)
      schema += output_col_name + ' ' + col_type + ','
    for output_col_name in self._reserved_cols:
      assert output_col_name in self._all_cols, \
          'Column: %s does not exist.' % output_col_name
      idx = self._all_cols.index(output_col_name)
      output_col_types = self._all_col_types[idx]
      schema += output_col_name + ' ' + output_col_types + ','
    schema = schema.rstrip(',')

    if partition_name and partition_val:
      sql = 'create table if not exists %s (%s) PARTITIONED BY (%s string)' % \
            (table_name, schema, partition_name)
      self._hive_util.run_sql(sql)
      sql = "LOAD DATA INPATH '%s/*' INTO TABLE %s PARTITION (%s=%s)" % \
            (self._hdfs_path, table_name, partition_name, partition_val)
      self._hive_util.run_sql(sql)
    else:
      sql = 'create table if not exists %s (%s)' % \
            (table_name, schema)
      self._hive_util.run_sql(sql)
      sql = "LOAD DATA INPATH '%s/*' INTO TABLE %s" % \
            (self._hdfs_path, table_name)
      self._hive_util.run_sql(sql)

  @property
  def out_of_range_exception(self):
    return tf.errors.OutOfRangeError
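
# A minimal construction sketch (the paths, configs and column lists below are
# hypothetical placeholders; `data_config` is a DatasetConfig proto and
# `hive_config` the hive connection config built by the caller):
#
#   predictor = HiveParquetPredictor(
#       model_path='hdfs://namenode:9000/user/easy_rec/export/final',
#       data_config=data_config,
#       hive_config=hive_config,
#       all_cols=['user_id', 'item_id', 'label'],
#       all_col_types=['string', 'string', 'bigint'])
#
# The prediction loop itself is implemented by the Predictor base class, which
# consumes the _get_dataset / _get_writer / _write_lines hooks defined above.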