def __init__()

in easy_rec/python/input/input.py [0:0]


  def __init__(self,
               data_config,
               feature_configs,
               input_path,
               task_index=0,
               task_num=1,
               check_mode=False,
               pipeline_config=None,
               **kwargs):
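    """Input base class constructor (docstring inferred from the code below).

    Args:
      data_config: a DatasetConfig proto describing input fields, labels
        and batching.
      feature_configs: a list of FeatureConfig protos.
      input_path: train/eval data path; None when no data is read.
      task_index: index of this worker, for distributed data sharding.
      task_num: total number of workers.
      check_mode: if True, enable extra input validation.
      pipeline_config: the full pipeline config proto; used here for
        ev_params and for metric fields.
    """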
    self._pipeline_config = pipeline_config
    self._data_config = data_config
    self._check_mode = check_mode
    logging.info('check_mode: %s', self._check_mode)
    # one of tf.estimator.ModeKeys.*; assigned before self._build is called
    self._mode = None
    # ev_params enables EmbeddingVariable support in the model config
    self._has_ev = (
        pipeline_config is not None and
        pipeline_config.model_config.HasField('ev_params'))

    # when auto_expand_input_fields is set, expand grouped input names
    # (e.g. 'fea[1-3]' -> 'fea1', 'fea2', 'fea3', see
    # config_util.auto_expand_names) into individual fields
    if self._data_config.auto_expand_input_fields:
      input_fields = list(self._data_config.input_fields)
      del self._data_config.input_fields[:]
      for field in input_fields:
        tmp_names = config_util.auto_expand_names(field.input_name)
        for tmp_name in tmp_names:
          one_field = DatasetConfig.Field()
          one_field.CopyFrom(field)
          one_field.input_name = tmp_name
          self._data_config.input_fields.append(one_field)

    self._input_fields = [x.input_name for x in data_config.input_fields]
    self._input_dims = [x.input_dim for x in data_config.input_fields]
    self._input_field_types = [x.input_type for x in data_config.input_fields]
    self._input_field_defaults = [
        x.default_val for x in data_config.input_fields
    ]
    self._label_fields = list(data_config.label_fields)
    self._feature_fields = list(data_config.feature_fields)
    self._label_sep = list(data_config.label_sep)
    self._label_dim = list(data_config.label_dim)
    if len(self._label_dim) < len(self._label_fields):
      self._label_dim.extend(
          [1] * (len(self._label_fields) - len(self._label_dim)))

    # user-defined label transform functions, keyed by input field name
    self._label_udf_map = {}
    for config in self._data_config.input_fields:
      if config.HasField('user_define_fn'):
        self._label_udf_map[config.input_name] = self._load_label_fn(config)

    self._batch_size = data_config.batch_size
    self._prefetch_size = data_config.prefetch_size
    self._feature_configs = list(feature_configs)
    self._task_index = task_index
    self._task_num = task_num

    self._input_path = input_path

    # find out the effective fields, i.e. those referenced by feature configs
    self._effective_fields = []

    # for multi-value inputs, the runtime types may differ from the types
    # declared in input_fields; this is used in create_multi_placeholders
    self._multi_value_types = {}
    self._multi_value_fields = set()

    self._normalizer_fn = {}
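    # walk the feature configs to collect the effective fields, record the
    # runtime dtypes of multi-value inputs, and resolve normalizer functions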
    for fc in self._feature_configs:
      for input_name in fc.input_names:
        assert input_name in self._input_fields, 'invalid input_name in %s' % str(
            fc)
        if input_name not in self._effective_fields:
          self._effective_fields.append(input_name)

      if fc.feature_type in [fc.TagFeature, fc.SequenceFeature]:
        # tags are strings when hashed or vocab-mapped, raw int64 ids otherwise
        if fc.hash_bucket_size > 0 or len(
            fc.vocab_list) > 0 or fc.HasField('vocab_file'):
          self._multi_value_types[fc.input_names[0]] = tf.string
        else:
          self._multi_value_types[fc.input_names[0]] = tf.int64
        self._multi_value_fields.add(fc.input_names[0])
        # an optional second input carries per-tag float weights
        if len(fc.input_names) > 1:
          self._multi_value_types[fc.input_names[1]] = tf.float32
          self._multi_value_fields.add(fc.input_names[1])

      if fc.feature_type == fc.RawFeature and fc.raw_input_dim > 1:
        self._multi_value_types[fc.input_names[0]] = tf.float32
        self._multi_value_fields.add(fc.input_names[0])

      # normalizer_fn is a dotted-path string resolved to a callable
      # (e.g. 'tf.math.log1p') that is later applied to the raw input
      if fc.HasField('normalizer_fn'):
        feature_name = fc.feature_name if fc.HasField(
            'feature_name') else fc.input_names[0]
        self._normalizer_fn[feature_name] = load_by_path(fc.normalizer_fn)

    # add sample weight to effective fields
    if self._data_config.HasField('sample_weight'):
      self._effective_fields.append(self._data_config.sample_weight)

    # add the uid_field of the GAUC metric and the session_id_field of the
    # SessionAUC metric, e.g. for
    #   eval_config { metrics_set { gauc { uid_field: 'user_id' } } }
    # 'user_id' must be read from the input as well
    if self._pipeline_config is not None:
      metrics = self._pipeline_config.eval_config.metrics_set
      for metric in metrics:
        metric_name = metric.WhichOneof('metric')
        if metric_name == 'gauc':
          uid = metric.gauc.uid_field
          if uid not in self._effective_fields:
            self._effective_fields.append(uid)
        elif metric_name == 'session_auc':
          sid = metric.session_auc.session_id_field
          if sid not in self._effective_fields:
            self._effective_fields.append(sid)

      # also collect metric fields from the task towers of multi-task models
      model_config = self._pipeline_config.model_config
      model_name = model_config.WhichOneof('model')
      if model_name in {'mmoe', 'esmm', 'dbmtl', 'simple_multi_task', 'ple'}:
        model = getattr(model_config, model_name)
        if model_name == 'esmm':
          towers = [model.ctr_tower, model.cvr_tower]
        else:
          towers = model.task_towers
        for tower in towers:
          metrics = tower.metrics_set
          for metric in metrics:
            metric_name = metric.WhichOneof('metric')
            if metric_name == 'gauc':
              uid = metric.gauc.uid_field
              if uid not in self._effective_fields:
                self._effective_fields.append(uid)
            elif metric_name == 'session_auc':
              sid = metric.session_auc.session_id_field
              if sid not in self._effective_fields:
                self._effective_fields.append(sid)

    self._effective_fids = [
        self._input_fields.index(x) for x in self._effective_fields
    ]
    # dedup and sort fids from small to large (set() alone does not
    # guarantee order)
    self._effective_fids = sorted(set(self._effective_fids))
    self._effective_fields = [
        self._input_fields[x] for x in self._effective_fids
    ]

    self._label_fids = [self._input_fields.index(x) for x in self._label_fields]

    # virtual fields generated by self._preprocess,
    # which become inputs to the feature columns
    self._appended_fields = []

    # negative sampler: built only for train and eval, i.e. when
    # input_path is set
    self._sampler = None
    if input_path is not None:
      self._sampler = sampler_lib.build(data_config)

    # expose the module-level get_type_defaults helper on the instance
    self.get_type_defaults = get_type_defaults
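
For context, a minimal sketch of how a concrete Input subclass might be
constructed from a pipeline config. It is illustrative only: it assumes the
CSVInput subclass keeps this constructor signature, and the config path is
hypothetical.

  from easy_rec.python.input.csv_input import CSVInput
  from easy_rec.python.utils import config_util

  # load the EasyRecConfig proto from a pipeline config file
  pipeline_config = config_util.get_configs_from_pipeline_file(
      'samples/model_config/my_model.config')  # hypothetical path

  # build the training input; task_index/task_num shard data across workers
  train_input = CSVInput(
      pipeline_config.data_config,
      config_util.get_compatible_feature_configs(pipeline_config),
      pipeline_config.train_input_path,
      task_index=0,
      task_num=1,
      pipeline_config=pipeline_config)

  # create_input() returns an input_fn usable with tf.estimator
  input_fn = train_input.create_input()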