easy_rec/python/input/datahub_input.py
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import json
import logging
import traceback
import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import gfile
from easy_rec.python.input.input import Input
from easy_rec.python.utils import odps_util
from easy_rec.python.utils.config_util import parse_time
try:
import common_io
except Exception:
common_io = None
try:
from datahub import DataHub
from datahub.exceptions import DatahubException
from datahub.models import RecordType
from datahub.models import CursorType
import urllib3
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
logging.getLogger('datahub.account').setLevel(logging.INFO)
except Exception:
  logging.warning(
      'DataHub is not installed. You can install it by: pip install pydatahub\n%s'
      % traceback.format_exc())
DataHub = None
class DataHubInput(Input):
"""DataHubInput is used for online train."""
def __init__(self,
data_config,
feature_config,
datahub_config,
task_index=0,
task_num=1,
check_mode=False,
pipeline_config=None):
super(DataHubInput,
self).__init__(data_config, feature_config, '', task_index, task_num,
check_mode, pipeline_config)
if DataHub is None:
      logging.error('please install datahub: pip install pydatahub; '
                    'Python 3.6 is recommended')
try:
self._num_epoch = 0
self._datahub_config = datahub_config
if self._datahub_config is not None:
akId = self._datahub_config.akId
akSecret = self._datahub_config.akSecret
endpoint = self._datahub_config.endpoint
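        # encode credentials/endpoint to utf-8 bytes when they are not str,
        # presumably for Python 2 compatibility with the DataHub client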
if not isinstance(akId, str):
akId = akId.encode('utf-8')
akSecret = akSecret.encode('utf-8')
endpoint = endpoint.encode('utf-8')
self._datahub = DataHub(akId, akSecret, endpoint)
else:
self._datahub = None
except Exception as ex:
      logging.warning('exception in init datahub: %s' % str(ex))
self._offset_dict = {}
if datahub_config:
shard_result = self._datahub.list_shard(self._datahub_config.project,
self._datahub_config.topic)
shards = shard_result.shards
self._all_shards = shards
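      # partition shards round-robin across workers so each task reads a
      # disjoint subset of the topic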
self._shards = [
shards[i] for i in range(len(shards)) if (i % task_num) == task_index
]
      logging.info('shards assigned to task[%d]: %s' %
                   (task_index, str(self._shards)))
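      # starting offsets: either resolved from a timestamp (offset_time),
      # loaded from a serialized dict (offset_info), or left empty so the
      # generator falls back to the OLDEST cursor per shard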
offset_type = datahub_config.WhichOneof('offset')
if offset_type == 'offset_time':
ts = parse_time(datahub_config.offset_time) * 1000
for x in self._all_shards:
ks = str(x.shard_id)
cursor_result = self._datahub.get_cursor(self._datahub_config.project,
self._datahub_config.topic,
ks, CursorType.SYSTEM_TIME,
ts)
logging.info('shard[%s] cursor = %s' % (ks, cursor_result))
self._offset_dict[ks] = cursor_result.cursor
elif offset_type == 'offset_info':
self._offset_dict = json.loads(self._datahub_config.offset_info)
else:
self._offset_dict = {}
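      # read the topic schema to map configured field names to
      # datahub column names and types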
self._dh_field_names = []
self._dh_field_types = []
topic_info = self._datahub.get_topic(
project_name=self._datahub_config.project,
topic_name=self._datahub_config.topic)
for field in topic_info.record_schema.field_list:
self._dh_field_names.append(field.name)
self._dh_field_types.append(field.type.value)
assert len(
self._feature_fields) > 0, 'data_config.feature_fields are not set.'
for x in self._feature_fields:
assert x in self._dh_field_names, 'feature_field[%s] is not in datahub' % x
# feature column ids in datahub schema
self._dh_fea_ids = [
self._dh_field_names.index(x) for x in self._feature_fields
]
for x in self._label_fields:
assert x in self._dh_field_names, 'label_field[%s] is not in datahub' % x
if self._data_config.HasField('sample_weight'):
x = self._data_config.sample_weight
assert x in self._dh_field_names, 'sample_weight[%s] is not in datahub' % x
self._read_cnt = 32
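      # filter out records whose second feature value is '-1024', which
      # appears to be used as an invalid-sample sentinel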
if len(self._dh_fea_ids) > 1:
self._filter_fea_func = lambda record: ''.join(
[record.values[x]
for x in self._dh_fea_ids]).split(chr(2))[1] == '-1024'
else:
dh_fea_id = self._dh_fea_ids[0]
self._filter_fea_func = lambda record: record.values[dh_fea_id].split(
self._data_config.separator)[1] == '-1024'
def _parse_record(self, *fields):
field_dict = {}
fields = list(fields)
def _dump_offsets():
all_offsets = {
x.shard_id: self._offset_dict[x.shard_id]
for x in self._shards
if x.shard_id in self._offset_dict
}
return json.dumps(all_offsets)
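    # attach the serialized shard offsets to every batch so they can be
    # saved with checkpoints and restored by restore()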
field_dict[Input.DATA_OFFSET] = tf.py_func(_dump_offsets, [], dtypes.string)
for x in self._label_fields:
dh_id = self._dh_field_names.index(x)
field_dict[x] = fields[dh_id]
feature_inputs = self.get_feature_input_fields()
    # feature fields only; labels and sample_weight are excluded
record_types = [
t for x, t in zip(self._input_fields, self._input_field_types)
if x in feature_inputs
]
feature_num = len(record_types)
feature_fields = [
fields[self._dh_field_names.index(x)] for x in self._feature_fields
]
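    # join all feature columns with the separator, split the joined string,
    # and reshape the flat values back to [batch_size, feature_num]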
feature = feature_fields[0]
for fea_id in range(1, len(feature_fields)):
feature = feature + self._data_config.separator + feature_fields[fea_id]
feature = tf.string_split(
feature, self._data_config.separator, skip_empty=False)
fields = tf.reshape(feature.values, [-1, feature_num])
for fid in range(feature_num):
field_dict[feature_inputs[fid]] = fields[:, fid]
return field_dict
def _preprocess(self, field_dict):
output_dict = super(DataHubInput, self)._preprocess(field_dict)
# append offset fields
if Input.DATA_OFFSET in field_dict:
output_dict[Input.DATA_OFFSET] = field_dict[Input.DATA_OFFSET]
# for _get_features to include DATA_OFFSET
if Input.DATA_OFFSET not in self._appended_fields:
self._appended_fields.append(Input.DATA_OFFSET)
return output_dict
def restore(self, checkpoint_path):
if checkpoint_path is None:
return
offset_path = checkpoint_path + '.offset'
if not gfile.Exists(offset_path):
return
logging.info('will restore datahub offset from %s' % offset_path)
with gfile.GFile(offset_path, 'r') as fin:
offset_dict = json.load(fin)
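      # only move offsets forward: keep the larger cursor per shard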
for k in offset_dict:
v = offset_dict[k]
ks = str(k)
if ks not in self._offset_dict or v > self._offset_dict[ks]:
self._offset_dict[ks] = v
def _is_data_empty(self, record):
is_empty = True
for fid in self._dh_fea_ids:
if record.values[fid] is not None and len(record.values[fid]) > 0:
is_empty = False
break
return is_empty
def _dump_record(self, record):
feas = []
for fid in range(len(record.values)):
if fid not in self._dh_fea_ids:
feas.append(self._dh_field_names[fid] + ':' + str(record.values[fid]))
return ';'.join(feas)
def _datahub_generator(self):
logging.info('start epoch[%d]' % self._num_epoch)
self._num_epoch += 1
try:
self._datahub.wait_shards_ready(self._datahub_config.project,
self._datahub_config.topic)
topic_result = self._datahub.get_topic(self._datahub_config.project,
self._datahub_config.topic)
      if topic_result.record_type != RecordType.TUPLE:
        logging.error('invalid datahub topic record_type(%s), expected TUPLE' %
                      str(topic_result.record_type))
record_schema = topic_result.record_schema
tid = 0
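      # round-robin over the shards assigned to this worker, reading
      # self._read_cnt records per request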
while True:
shard_id = self._shards[tid].shard_id
tid += 1
if tid >= len(self._shards):
tid = 0
if shard_id not in self._offset_dict:
cursor_result = self._datahub.get_cursor(self._datahub_config.project,
self._datahub_config.topic,
shard_id, CursorType.OLDEST)
cursor = cursor_result.cursor
else:
cursor = self._offset_dict[shard_id]
get_result = self._datahub.get_tuple_records(
self._datahub_config.project, self._datahub_config.topic, shard_id,
record_schema, cursor, self._read_cnt)
count = get_result.record_count
if count == 0:
continue
for row_id, record in enumerate(get_result.records):
if self._is_data_empty(record):
logging.warning('skip empty data record: %s' %
self._dump_record(record))
continue
if self._filter_fea_func is not None:
if self._filter_fea_func(record):
logging.warning('filter data record: %s' %
self._dump_record(record))
continue
yield tuple(list(record.values))
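        # advance the in-memory offset so the next read continues from
        # next_cursor and the new offset is dumped into DATA_OFFSET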
if shard_id not in self._offset_dict or get_result.next_cursor > self._offset_dict[
shard_id]:
self._offset_dict[shard_id] = get_result.next_cursor
except DatahubException as ex:
logging.error('DatahubException: %s' % str(ex))
def _build(self, mode, params):
if mode == tf.estimator.ModeKeys.TRAIN:
assert self._datahub is not None, 'datahub_train_input is not set'
elif mode == tf.estimator.ModeKeys.EVAL:
assert self._datahub is not None, 'datahub_eval_input is not set'
# get input types
list_types = [
odps_util.odps_type_2_tf_type(x) for x in self._dh_field_types
]
list_types = tuple(list_types)
list_shapes = [
tf.TensorShape([]) for x in range(0, len(self._dh_field_types))
]
list_shapes = tuple(list_shapes)
# read datahub
dataset = tf.data.Dataset.from_generator(
self._datahub_generator,
output_types=list_types,
output_shapes=list_shapes)
if mode == tf.estimator.ModeKeys.TRAIN:
if self._data_config.shuffle:
dataset = dataset.shuffle(
self._data_config.shuffle_buffer_size,
seed=2020,
reshuffle_each_iteration=True)
dataset = dataset.batch(self._data_config.batch_size)
dataset = dataset.map(
self._parse_record,
num_parallel_calls=self._data_config.num_parallel_calls)
    # preprocess is necessary to transform the data
    # so that it can be fed into FeatureColumns
dataset = dataset.map(
map_func=self._preprocess,
num_parallel_calls=self._data_config.num_parallel_calls)
dataset = dataset.prefetch(buffer_size=self._prefetch_size)
if mode != tf.estimator.ModeKeys.PREDICT:
dataset = dataset.map(lambda x:
(self._get_features(x), self._get_labels(x)))
else:
      dataset = dataset.map(lambda x: self._get_features(x))
return dataset