def __init__()

in easy_rec/python/input/datahub_input.py [0:0]
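
Constructor for DataHubInput: it creates the DataHub client from akId/akSecret/endpoint, partitions the topic's shards round-robin across workers, resolves a starting cursor per shard from the configured offset, fetches the topic schema, and validates that every configured feature, label, and sample-weight field exists in that schema.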


  def __init__(self,
               data_config,
               feature_config,
               datahub_config,
               task_index=0,
               task_num=1,
               check_mode=False,
               pipeline_config=None):
    super(DataHubInput,
          self).__init__(data_config, feature_config, '', task_index, task_num,
                         check_mode, pipeline_config)
    if DataHub is None:
      logging.error(
          'please install datahub: pip install pydatahub; Python 3.6 recommended')
    try:
      self._num_epoch = 0
      self._datahub_config = datahub_config
      if self._datahub_config is not None:
        akId = self._datahub_config.akId
        akSecret = self._datahub_config.akSecret
        endpoint = self._datahub_config.endpoint
        if not isinstance(akId, str):
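          # Python 2: protobuf string fields are unicode; DataHub expects bytes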
          akId = akId.encode('utf-8')
          akSecret = akSecret.encode('utf-8')
          endpoint = endpoint.encode('utf-8')
        self._datahub = DataHub(akId, akSecret, endpoint)
      else:
        self._datahub = None
    except Exception as ex:
      logging.info('exception in init datahub: %s' % str(ex))
    self._offset_dict = {}
    if datahub_config:
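      # list the topic's shards and keep a round-robin slice for this worker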
      shard_result = self._datahub.list_shard(self._datahub_config.project,
                                              self._datahub_config.topic)
      shards = shard_result.shards
      self._all_shards = shards
      self._shards = [
          shards[i] for i in range(len(shards)) if (i % task_num) == task_index
      ]
      logging.info('shards assigned to task[%d]: %s' %
                   (task_index, str(self._shards)))

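      # resolve a starting cursor per shard: offset_time is parsed to epoch
      # seconds and scaled to the millisecond timestamps that SYSTEM_TIME
      # cursors expect; offset_info is a JSON string of previously saved
      # per-shard offsets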
      offset_type = datahub_config.WhichOneof('offset')
      if offset_type == 'offset_time':
        ts = parse_time(datahub_config.offset_time) * 1000
        for x in self._all_shards:
          ks = str(x.shard_id)
          cursor_result = self._datahub.get_cursor(self._datahub_config.project,
                                                   self._datahub_config.topic,
                                                   ks, CursorType.SYSTEM_TIME,
                                                   ts)
          logging.info('shard[%s] cursor = %s' % (ks, cursor_result))
          self._offset_dict[ks] = cursor_result.cursor
      elif offset_type == 'offset_info':
        self._offset_dict = json.loads(self._datahub_config.offset_info)
      else:
        self._offset_dict = {}

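      # fetch the topic schema so configured fields can be mapped to and
      # validated against the DataHub record columns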
      self._dh_field_names = []
      self._dh_field_types = []
      topic_info = self._datahub.get_topic(
          project_name=self._datahub_config.project,
          topic_name=self._datahub_config.topic)
      for field in topic_info.record_schema.field_list:
        self._dh_field_names.append(field.name)
        self._dh_field_types.append(field.type.value)

      assert len(
          self._feature_fields) > 0, 'data_config.feature_fields are not set.'

      for x in self._feature_fields:
        assert x in self._dh_field_names, 'feature_field[%s] is not in datahub' % x

      # feature column ids in datahub schema
      self._dh_fea_ids = [
          self._dh_field_names.index(x) for x in self._feature_fields
      ]

      for x in self._label_fields:
        assert x in self._dh_field_names, 'label_field[%s] is not in datahub' % x

      if self._data_config.HasField('sample_weight'):
        x = self._data_config.sample_weight
        assert x in self._dh_field_names, 'sample_weight[%s] is not in datahub' % x

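      # max number of records to fetch per DataHub read call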
      self._read_cnt = 32

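      # predicate marking records to filter: true when the second feature
      # value equals the '-1024' sentinel; with multiple feature columns the
      # concatenated string is expected to embed chr(2) separators, with a
      # single column data_config.separator is used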
      if len(self._dh_fea_ids) > 1:
        self._filter_fea_func = lambda record: ''.join(
            [record.values[x]
             for x in self._dh_fea_ids]).split(chr(2))[1] == '-1024'
      else:
        dh_fea_id = self._dh_fea_ids[0]
        self._filter_fea_func = lambda record: record.values[dh_fea_id].split(
            self._data_config.separator)[1] == '-1024'
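
For reference, the round-robin shard split above can be reproduced standalone; the shard ids and worker count in this sketch are made up:

shards = ['0', '1', '2', '3', '4']  # hypothetical shard ids
task_num = 2
for task_index in range(task_num):
  assigned = [shards[i] for i in range(len(shards)) if (i % task_num) == task_index]
  print('task[%d] shards: %s' % (task_index, assigned))
# task[0] shards: ['0', '2', '4']
# task[1] shards: ['1', '3']

Likewise, a minimal sketch of the single-column filter predicate, assuming a separator of ',' and a namedtuple standing in for a DataHub record:

import collections

Record = collections.namedtuple('Record', ['values'])  # stand-in for a DataHub record

separator = ','  # assumed value of data_config.separator
filter_fea_func = lambda record: record.values[0].split(separator)[1] == '-1024'

print(filter_fea_func(Record(values=['0.5,-1024,3.2'])))  # True: '-1024' in the second slot
print(filter_fea_func(Record(values=['0.5,7.0,3.2'])))    # False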