def _preprocess()

in easy_rec/python/input/input.py [0:0]


  def _preprocess(self, field_dict):
    """Preprocess the feature columns.

    Preprocess some feature columns, such as TagFeature or LookupFeature.
    It is expected to handle batch inputs and single input, and may be
    customized in subclasses.

    Args:
      field_dict: dict of string to tensor; tensors are dense,
          could be of shape [batch_size], [batch_size, None], or of shape []

    Returns:
      dict with two keys:
          'feature': dict of feature name to tensor; some of the tensors are
              transformed into sparse tensors, such as input tensors of tag
              features and lookup features
          'label': dict of label name to tensor
    """
    parsed_dict = {}

    # Negative sampling (train/eval only): fetch sampled negatives and append
    # them to the current batch so feature parsing below sees one combined
    # tensor per field.
    if self._sampler is not None and self._mode != tf.estimator.ModeKeys.PREDICT:
      if self._mode != tf.estimator.ModeKeys.TRAIN:
        # eval uses a separate sample count on the sampler
        self._sampler.set_eval_num_sample()
      sampler_type = self._data_config.WhichOneof('sampler')
      sampler_config = getattr(self._data_config, sampler_type)
      item_ids = field_dict[sampler_config.item_id_field]
      if sampler_type in ['negative_sampler', 'negative_sampler_in_memory']:
        sampled = self._sampler.get(item_ids)
      elif sampler_type == 'negative_sampler_v2':
        user_ids = field_dict[sampler_config.user_id_field]
        sampled = self._sampler.get(user_ids, item_ids)
      elif sampler_type.startswith('hard_negative_sampler'):
        user_ids = field_dict[sampler_config.user_id_field]
        sampled = self._sampler.get(user_ids, item_ids)
      else:
        raise ValueError('Unknown sampler %s' % sampler_type)
      # `sampled` presumably maps field name -> tensor of sampled values
      # (TODO confirm against the sampler implementation).
      for k, v in sampled.items():
        if k in field_dict:
          # field already present in the batch: stack negatives after
          # positives along the batch axis
          field_dict[k] = tf.concat([field_dict[k], v], axis=0)
        else:
          # use logging instead of print for consistency with the rest of
          # this module and so the message goes through log handlers
          logging.info('appended fields: %s' % k)
          parsed_dict[k] = v
          self._appended_fields.append(k)

    # Dispatch each feature config to its type-specific parser; parsers fill
    # parsed_dict in place from field_dict.
    for fc in self._feature_configs:
      feature_name = fc.feature_name
      feature_type = fc.feature_type
      if feature_type == fc.TagFeature:
        self._parse_tag_feature(fc, parsed_dict, field_dict)
      elif feature_type == fc.LookupFeature:
        assert feature_name is not None and feature_name != ''
        # lookup features require exactly a key input and a map input
        assert len(fc.input_names) == 2
        parsed_dict[feature_name] = self._lookup_preprocess(fc, field_dict)
      elif feature_type == fc.SequenceFeature:
        self._parse_seq_feature(fc, parsed_dict, field_dict)
      elif feature_type == fc.RawFeature:
        self._parse_raw_feature(fc, parsed_dict, field_dict)
      elif feature_type == fc.IdFeature:
        self._parse_id_feature(fc, parsed_dict, field_dict)
      elif feature_type == fc.ExprFeature:
        self._parse_expr_feature(fc, parsed_dict, field_dict)
      elif feature_type == fc.ComboFeature:
        self._parse_combo_feature(fc, parsed_dict, field_dict)
      else:
        # pass-through for untyped features: copy inputs verbatim; extra
        # inputs beyond the first are suffixed with their index
        feature_name = fc.feature_name if fc.HasField(
            'feature_name') else fc.input_names[0]
        for input_id, input_name in enumerate(fc.input_names):
          if input_id > 0:
            key = feature_name + '_' + str(input_id)
          else:
            key = feature_name
          parsed_dict[key] = field_dict[input_name]

    # Build the label dict: apply optional user-defined transforms, then
    # convert string labels to float32 (splitting multi-dim labels first).
    label_dict = {}
    for input_id, input_name in enumerate(self._label_fields):
      if input_name not in field_dict:
        continue
      if input_name in self._label_udf_map:
        udf, udf_class, dtype = self._label_udf_map[input_name]
        if dtype is None or dtype == '':
          # no result dtype configured: udf is assumed to be a TF-graph
          # function applied directly to the tensor
          logging.info('apply tensorflow function transform: %s' % udf_class)
          field_dict[input_name] = udf(field_dict[input_name])
        else:
          # dtype configured: wrap the python udf in py_func; py_func loses
          # static shape info, so restore the [None] batch shape
          logging.info('apply py_func transform: %s' % udf_class)
          field_dict[input_name] = tf.py_func(
              udf, [field_dict[input_name]], Tout=get_tf_type(dtype))
          field_dict[input_name].set_shape(tf.TensorShape([None]))

      if field_dict[input_name].dtype == tf.string:
        if self._label_dim[input_id] > 1:
          # multi-dimensional label encoded as a separated string:
          # split, then reshape to [batch_size, label_dim]
          logging.info('will split labels[%d]=%s' % (input_id, input_name))
          check_list = [
              tf.py_func(
                  check_split, [
                      field_dict[input_name], self._label_sep[input_id],
                      self._label_dim[input_id], input_name
                  ],
                  Tout=tf.bool)
          ] if self._check_mode else []
          with tf.control_dependencies(check_list):
            label_dict[input_name] = tf.string_split(
                field_dict[input_name], self._label_sep[input_id]).values
            label_dict[input_name] = tf.reshape(label_dict[input_name],
                                                [-1, self._label_dim[input_id]])
        else:
          label_dict[input_name] = field_dict[input_name]
        # optionally validate the strings are numeric before conversion
        check_list = [
            tf.py_func(
                check_string_to_number, [label_dict[input_name], input_name],
                Tout=tf.bool)
        ] if self._check_mode else []
        with tf.control_dependencies(check_list):
          label_dict[input_name] = tf.string_to_number(
              label_dict[input_name], tf.float32, name=input_name)
      else:
        assert field_dict[input_name].dtype in [
            tf.float32, tf.double, tf.int32, tf.int64
        ], 'invalid label dtype: %s' % str(field_dict[input_name].dtype)
        label_dict[input_name] = field_dict[input_name]

    if self._mode != tf.estimator.ModeKeys.PREDICT:
      # derive additional labels from the existing ones via configured
      # transform functions
      for func_config in self._data_config.extra_label_func:
        lbl_name = func_config.label_name
        func_name = func_config.label_func
        logging.info('generating new label `%s` by transform: %s' %
                     (lbl_name, func_name))
        lbl_fn = load_by_path(func_name)
        label_dict[lbl_name] = lbl_fn(label_dict)

      if self._data_config.HasField('sample_weight'):
        parsed_dict[constant.SAMPLE_WEIGHT] = field_dict[
            self._data_config.sample_weight]

    # propagate the data offset field (used for restore/skip semantics —
    # TODO confirm exact usage at the consumer side)
    if Input.DATA_OFFSET in field_dict:
      parsed_dict[Input.DATA_OFFSET] = field_dict[Input.DATA_OFFSET]
    return {'feature': parsed_dict, 'label': label_dict}