def _get_dense_tensor_internal()

in easy_rec/python/compat/feature_column/feature_column.py [0:0]


  def _get_dense_tensor_internal(self,
                                 inputs,
                                 weight_collections=None,
                                 trainable=None):
    """Private method that follows the signature of _get_dense_tensor.

    Fetches the embedding variable registered under
    `self.shared_embedding_collection_name`, or creates it if the collection
    is empty. Then returns the embedding lookup of this column's sparse ids.

    Args:
      inputs: forwarded to `self.categorical_column._get_sparse_tensors`.
      weight_collections: collections passed to the sparse-tensor builder and,
        when a new variable is created, to the variable constructor.
      trainable: combined as `self.trainable and trainable` when a new
        variable is created. Note that this evaluates to `trainable` itself
        (possibly None) whenever `self.trainable` is truthy.

    Returns:
      The output of `embedding_lookup_ragged` when the ids are a RaggedTensor,
      otherwise the output of `embedding_ops.safe_embedding_lookup_sparse`.

    Raises:
      ValueError: if the shared collection holds more than one variable; if
        the shared variable has an unexpected shape (checked only when
        `self.ev_params` is None); or if ragged ids come with weights.
    """
    # This method is called from a variable_scope with name _var_scope_name,
    # which is shared among all shared embeddings. Open a name_scope here, so
    # that the ops for different columns have distinct names.
    with ops.name_scope(None, default_name=self.name):
      # Get sparse IDs and weights.
      sparse_tensors = self.categorical_column._get_sparse_tensors(  # pylint: disable=protected-access
          inputs,
          weight_collections=weight_collections,
          trainable=trainable)
      sparse_ids = sparse_tensors.id_tensor
      sparse_weights = sparse_tensors.weight_tensor

      embedding_shape = (self.categorical_column._num_buckets, self.dimension)  # pylint: disable=protected-access
      shared_embedding_collection = ops.get_collection(
          self.shared_embedding_collection_name)
      if shared_embedding_collection:
        # Another column sharing this embedding already created the variable.
        if len(shared_embedding_collection) > 1:
          raise ValueError(
              'Collection {} can only contain one variable. '
              'Suggested fix A: Choose a unique name for this collection. '
              'Suggested fix B: Do not add any variables to this collection. '
              'The feature_column library already adds a variable under the '
              'hood.'.format(shared_embedding_collection))
        embedding_weights = shared_embedding_collection[0]
        # EmbeddingVariables are created below with only `embedding_dim`
        # (no bucket count), so the static-shape check applies only to
        # regular variables, i.e. when ev_params is None.
        if (self.ev_params is None and
            embedding_weights.get_shape() != embedding_shape):
          raise ValueError(
              'Shared embedding collection {} contains variable {} of '
              'unexpected shape {}. Expected shape is {}. '
              'Suggested fix A: Choose a unique name for this collection. '
              'Suggested fix B: Do not add any variables to this collection. '
              'The feature_column library already adds a variable under the '
              'hood.'.format(self.shared_embedding_collection_name,
                             embedding_weights.name,
                             embedding_weights.get_shape(), embedding_shape))
      else:
        if self.ev_params is None:
          embedding_weights = variable_scope.get_variable(
              name='embedding_weights',
              shape=embedding_shape,
              dtype=dtypes.float32,
              initializer=self.initializer,
              trainable=self.trainable and trainable,
              partitioner=self.partitioner,
              collections=weight_collections)
        else:
          # at eval or inference time, it is necessary to set
          # the initializers to zeros, so that new key will
          # get zero embedding
          if os.environ.get('tf.estimator.mode', '') != \
             os.environ.get('tf.estimator.ModeKeys.TRAIN', 'train'):
            initializer = init_ops.zeros_initializer()
          else:
            initializer = self.initializer
          # The counter-filter API differs between TF builds: newer ones take
          # an `ev_option`, older ones take `filter_options`.
          # NOTE(review): the probe checks `EmbeddingVariableConfig` but then
          # uses `EmbeddingVariableOption`; confirm both always coexist.
          extra_args = {}
          if 'EmbeddingVariableConfig' in dir(variables):
            ev_option = variables.EmbeddingVariableOption()
            ev_option.filter_strategy = variables.CounterFilter(
                filter_freq=self.ev_params.filter_freq)
            extra_args['ev_option'] = ev_option
          else:
            extra_args['filter_options'] = variables.CounterFilterOptions(
                self.ev_params.filter_freq)
          embedding_weights = variable_scope.get_embedding_variable(
              name='embedding_weights',
              embedding_dim=self.dimension,
              initializer=initializer,
              trainable=self.trainable and trainable,
              partitioner=self.partitioner,
              collections=weight_collections,
              steps_to_live=self.ev_params.steps_to_live,
              **extra_args)

        # Register the new variable so sibling shared columns reuse it.
        ops.add_to_collection(self.shared_embedding_collection_name,
                              embedding_weights)
      if self.ckpt_to_load_from is not None:
        to_restore = embedding_weights
        if isinstance(to_restore, variables.PartitionedVariable):
          to_restore = to_restore._get_variable_list()  # pylint: disable=protected-access
        checkpoint_utils.init_from_checkpoint(
            self.ckpt_to_load_from, {self.tensor_name_in_ckpt: to_restore})

      if 'RaggedTensor' in str(type(sparse_ids)):
        # The ragged lookup path is only used without weights. Raise instead
        # of `assert` so the check is not stripped under `python -O`.
        if sparse_weights is not None:
          raise ValueError(
              'Weighted ragged ids are not supported by shared embedding '
              'column {}.'.format(self.name))
        return embedding_lookup_ragged(
            embedding_weights=embedding_weights,
            ragged_ids=sparse_ids,
            ragged_weights=sparse_weights,
            combiner=self.combiner,
            max_norm=self.max_norm,
            name='%s_weights' % self.name)

      # Return embedding lookup result.
      return embedding_ops.safe_embedding_lookup_sparse(
          embedding_weights=embedding_weights,
          sparse_ids=sparse_ids,
          sparse_weights=sparse_weights,
          combiner=self.combiner,
          name='%s_weights' % self.name,
          max_norm=self.max_norm)