def __init__()

in easy_rec/python/feature_column/feature_column.py [0:0]


  def __init__(self,
               feature_configs,
               wide_deep_dict=None,
               wide_output_dim=-1,
               ev_params=None):
    """Initializes a `FeatureColumnParser`.

    Construction runs in four phases:
      1. count how many feature configs use each `embedding_name` and keep
         only the names used by more than one config (shared embeddings);
      2. parse every feature config with the matching `parse_*` method;
      3. build the deep / wide shared embedding columns for each shared
         embedding name;
      4. replace every `SharedEmbedding` placeholder found in the deep, wide
         and sequence column dicts with the column returned by
         `_get_shared_embedding_column`.

    Args:
      feature_configs: collections of
            easy_rec.python.protos.feature_config_pb2.FeatureConfig
            or easy_rec.python.protos.feature_config_pb2.FeatureConfigV2.features
      wide_deep_dict: dict of {feature_name:WideOrDeep}, passed by
        easy_rec.python.layers.input_layer.InputLayer, it is defined in
        easy_rec.python.protos.easy_rec_model_pb2.EasyRecModel.feature_groups.
        Defaults to None, which is treated as a new empty dict.
      wide_output_dim: output dimension for wide columns
      ev_params: params used by EmbeddingVariable, which is provided by pai-tf

    Raises:
      AssertionError: if two configs sharing an embedding_name disagree on
        embedding_dim, combiner, initializer or max_partitions, if a config
        is not a FeatureConfig, or if a feature type is not supported.
    """
    self._feature_configs = feature_configs
    self._wide_output_dim = wide_output_dim
    # Never use a mutable default argument here: the dict is stored on the
    # instance, so a `{}` default would be shared by every parser created
    # without an explicit wide_deep_dict.
    self._wide_deep_dict = {} if wide_deep_dict is None else wide_deep_dict
    self._deep_columns = {}
    self._wide_columns = {}
    self._sequence_columns = {}

    # embedding_name -> number of feature configs using that name
    self._share_embed_names = {}
    # embedding_name -> copy of a feature config carrying the embedding setup
    self._share_embed_infos = {}

    self._vocab_size = {}

    self._global_ev_params = None
    if ev_params is not None:
      self._global_ev_params = self._build_ev_params(ev_params)

    def _cmp_embed_config(a, b):
      """Returns True if two configs agree on all embedding settings."""
      return (a.embedding_dim == b.embedding_dim and
              a.combiner == b.combiner and
              a.initializer == b.initializer and
              a.max_partitions == b.max_partitions and
              a.embedding_name == b.embedding_name)

    def _set_max_seq_length(share_embed_fcs, embed_info):
      """Copies max_seq_len (or -1 when unset) onto each shared column."""
      if embed_info.HasField('max_seq_len'):
        max_seq_len = embed_info.max_seq_len
      else:
        max_seq_len = -1
      for fc in share_embed_fcs:
        fc.max_seq_length = max_seq_len

    # phase 1: collect the embedding names and count their usages
    for config in self._feature_configs:
      if not config.HasField('embedding_name'):
        continue
      embed_name = config.embedding_name

      if embed_name in self._share_embed_names:
        shared_info = self._share_embed_infos[embed_name]
        assert _cmp_embed_config(config, shared_info), \
            'shared embed info of [%s] is not matched [%s] vs [%s]' % (
                embed_name, config, shared_info)
        self._share_embed_names[embed_name] += 1
        # A sequence feature config replaces the stored info.
        # NOTE(review): presumably so that sequence-only settings such as
        # max_seq_len are available below -- confirm.
        if config.feature_type == FeatureConfig.FeatureType.SequenceFeature:
          self._share_embed_infos[embed_name] = copy_obj(config)
      else:
        self._share_embed_names[embed_name] = 1
        self._share_embed_infos[embed_name] = copy_obj(config)

    # remove not shared embedding names (used by exactly one config); the
    # names are collected first because the dicts are modified in place
    not_shared = [
        name for name, num in self._share_embed_names.items() if num == 1
    ]
    for embed_name in not_shared:
      del self._share_embed_names[embed_name]
      del self._share_embed_infos[embed_name]

    logging.info('shared embeddings[num=%d]', len(self._share_embed_names))
    for embed_name, share_num in self._share_embed_names.items():
      logging.info('\t%s: share_num[%d], share_info[%s]', embed_name,
                   share_num, self._share_embed_infos[embed_name])
    # per shared embedding name: the columns collected for the deep / wide
    # part; both are replaced below by the built shared embedding columns
    self._deep_share_embed_columns = {
        embed_name: [] for embed_name in self._share_embed_names
    }
    self._wide_share_embed_columns = {
        embed_name: [] for embed_name in self._share_embed_names
    }

    # phase 2: parse each feature config according to its feature type
    self._feature_vocab_size = {}
    for config in self._feature_configs:
      assert isinstance(config, FeatureConfig)
      try:
        if config.feature_type == config.IdFeature:
          self.parse_id_feature(config)
        elif config.feature_type == config.TagFeature:
          self.parse_tag_feature(config)
        elif config.feature_type == config.RawFeature:
          self.parse_raw_feature(config)
        elif config.feature_type == config.ComboFeature:
          self.parse_combo_feature(config)
        elif config.feature_type == config.LookupFeature:
          self.parse_lookup_feature(config)
        elif config.feature_type == config.SequenceFeature:
          self.parse_sequence_feature(config)
        elif config.feature_type == config.ExprFeature:
          self.parse_expr_feature(config)
        elif config.feature_type != config.PassThroughFeature:
          # Raised explicitly instead of `assert False` so that the check is
          # not stripped when python runs with -O.
          raise AssertionError(
              'invalid feature type: %s' % config.feature_type)
      except FeatureKeyError:
        # NOTE(review): FeatureKeyError is presumably raised by the parse_*
        # methods for features that are not listed in wide_deep_dict; such
        # features are skipped on purpose -- confirm.
        pass

    # phase 3: build the shared embedding columns
    for embed_name in self._share_embed_names:
      embed_info = self._share_embed_infos[embed_name]

      initializer = None
      if embed_info.HasField('initializer'):
        initializer = hyperparams_builder.build_initializer(
            embed_info.initializer)

      partitioner = self._build_partitioner(embed_info)

      # ev_params set on the feature config override the global ones
      if embed_info.HasField('ev_params'):
        embed_ev_params = self._build_ev_params(embed_info.ev_params)
      else:
        embed_ev_params = self._global_ev_params

      # for handling share embedding columns
      deep_fcs = self._deep_share_embed_columns[embed_name]
      if deep_fcs:
        share_embed_fcs = feature_column.shared_embedding_columns(
            deep_fcs,
            embed_info.embedding_dim,
            initializer=initializer,
            shared_embedding_collection_name=embed_name,
            combiner=embed_info.combiner,
            partitioner=partitioner,
            ev_params=embed_ev_params)
        _set_max_seq_length(share_embed_fcs, embed_info)
        self._deep_share_embed_columns[embed_name] = share_embed_fcs

      # for handling wide share embedding columns; the wide part always uses
      # wide_output_dim and the 'sum' combiner
      wide_fcs = self._wide_share_embed_columns[embed_name]
      if wide_fcs:
        share_embed_fcs = feature_column.shared_embedding_columns(
            wide_fcs,
            self._wide_output_dim,
            initializer=initializer,
            shared_embedding_collection_name=embed_name + '_wide',
            combiner='sum',
            partitioner=partitioner,
            ev_params=embed_ev_params)
        _set_max_seq_length(share_embed_fcs, embed_info)
        self._wide_share_embed_columns[embed_name] = share_embed_fcs

    # phase 4: replace the SharedEmbedding placeholders with real columns;
    # only values of existing keys are reassigned, so iterating is safe
    for fc_name, fc in self._deep_columns.items():
      if isinstance(fc, SharedEmbedding):
        self._deep_columns[fc_name] = self._get_shared_embedding_column(fc)

    for fc_name, fc in self._wide_columns.items():
      if isinstance(fc, SharedEmbedding):
        self._wide_columns[fc_name] = self._get_shared_embedding_column(
            fc, deep=False)

    for fc_name, fc in self._sequence_columns.items():
      if isinstance(fc, SharedEmbedding):
        self._sequence_columns[fc_name] = self._get_shared_embedding_column(fc)