in easy_rec/python/input/datahub_input.py [0:0]
def __init__(self,
             data_config,
             feature_config,
             datahub_config,
             task_index=0,
             task_num=1,
             check_mode=False,
             pipeline_config=None):
  """Input that streams training samples from an Aliyun DataHub topic.

  Args:
    data_config: DatasetConfig proto; supplies feature_fields, label_fields,
      separator and optional sample_weight used below.
    feature_config: feature configuration proto, forwarded to the base class.
    datahub_config: DatahubServer proto with akId/akSecret/endpoint,
      project/topic and a oneof 'offset' start position; may be None.
    task_index: index of this worker; shards are assigned round-robin
      by ``shard_index % task_num == task_index``.
    task_num: total number of workers.
    check_mode: whether to enable input check mode (forwarded to base).
    pipeline_config: full pipeline config proto, optional.
  """
  super(DataHubInput,
        self).__init__(data_config, feature_config, '', task_index, task_num,
                       check_mode, pipeline_config)
  if DataHub is None:
    # Bugfix: the original passed a second string to logging.error, which is
    # interpreted as a %-format argument; with no placeholder in the message
    # the record failed to render and the install hint was lost. Log one
    # complete message instead.
    logging.error('please install datahub: '
                  'pip install pydatahub ;Python 3.6 recommended')
  try:
    self._num_epoch = 0
    self._datahub_config = datahub_config
    if self._datahub_config is not None:
      akId = self._datahub_config.akId
      akSecret = self._datahub_config.akSecret
      endpoint = self._datahub_config.endpoint
      # pydatahub expects byte strings here when the proto yields
      # non-str (unicode) values — presumably a py2 compatibility path.
      if not isinstance(akId, str):
        akId = akId.encode('utf-8')
        akSecret = akSecret.encode('utf-8')
        endpoint = endpoint.encode('utf-8')
      self._datahub = DataHub(akId, akSecret, endpoint)
    else:
      self._datahub = None
  except Exception as ex:
    # Best-effort init: keep going with self._datahub possibly unset, but
    # log at error level (was info) with lazy %-args so the failure is
    # visible in the logs.
    logging.error('exception in init datahub: %s', ex)
  self._offset_dict = {}
  if datahub_config:
    shard_result = self._datahub.list_shard(self._datahub_config.project,
                                            self._datahub_config.topic)
    shards = shard_result.shards
    self._all_shards = shards
    # Round-robin partition of topic shards across workers; this worker
    # only reads shards where (index % task_num) == task_index.
    self._shards = [
        shards[i] for i in range(len(shards)) if (i % task_num) == task_index
    ]
    # NOTE(review): message says 'all shards' but logs only this worker's
    # subset; kept byte-identical to preserve log output.
    logging.info('all shards: %s' % str(self._shards))
    # Resolve the starting cursor per shard from the configured offset oneof.
    offset_type = datahub_config.WhichOneof('offset')
    if offset_type == 'offset_time':
      # DataHub SYSTEM_TIME cursors take milliseconds since epoch.
      ts = parse_time(datahub_config.offset_time) * 1000
      for x in self._all_shards:
        ks = str(x.shard_id)
        cursor_result = self._datahub.get_cursor(self._datahub_config.project,
                                                 self._datahub_config.topic,
                                                 ks, CursorType.SYSTEM_TIME,
                                                 ts)
        logging.info('shard[%s] cursor = %s' % (ks, cursor_result))
        self._offset_dict[ks] = cursor_result.cursor
    elif offset_type == 'offset_info':
      # Resume from a previously serialized {shard_id: cursor} mapping.
      self._offset_dict = json.loads(self._datahub_config.offset_info)
    else:
      self._offset_dict = {}
    # Pull the topic schema and validate configured fields against it.
    self._dh_field_names = []
    self._dh_field_types = []
    topic_info = self._datahub.get_topic(
        project_name=self._datahub_config.project,
        topic_name=self._datahub_config.topic)
    for field in topic_info.record_schema.field_list:
      self._dh_field_names.append(field.name)
      self._dh_field_types.append(field.type.value)
    assert len(
        self._feature_fields) > 0, 'data_config.feature_fields are not set.'
    for x in self._feature_fields:
      assert x in self._dh_field_names, 'feature_field[%s] is not in datahub' % x
    # feature column ids in datahub schema
    self._dh_fea_ids = [
        self._dh_field_names.index(x) for x in self._feature_fields
    ]
    for x in self._label_fields:
      assert x in self._dh_field_names, 'label_field[%s] is not in datahub' % x
    if self._data_config.HasField('sample_weight'):
      x = self._data_config.sample_weight
      assert x in self._dh_field_names, 'sample_weight[%s] is not in datahub' % x
    # Number of records fetched per DataHub read request.
    self._read_cnt = 32
    # Predicate marking records whose second feature component equals the
    # sentinel '-1024'; multi-column features are joined with chr(2) first.
    if len(self._dh_fea_ids) > 1:
      self._filter_fea_func = lambda record: ''.join(
          [record.values[x]
           for x in self._dh_fea_ids]).split(chr(2))[1] == '-1024'
    else:
      dh_fea_id = self._dh_fea_ids[0]
      self._filter_fea_func = lambda record: record.values[dh_fea_id].split(
          self._data_config.separator)[1] == '-1024'