def embed_single_feature()

in research/carls/models/caml/sparse_features.py


def embed_single_feature(keys: tf.Tensor,
                         de_config: de_config_pb2.DynamicEmbeddingConfig,
                         embedding_dim: int,
                         sigma_dim: int,
                         feature_name: typing.Text,
                         em_steps: int = 0,
                         variables: BagOfSparseFeatureVariables = None,
                         service_address: typing.Text = "",
                         timeout_ms: int = -1):
  """Embeds a single feature based on the embedding decomposition formula.

  Args:
    keys: A string `Tensor` of shape [batch_size] for a single-key batch or
      [batch_size, None] for a multi-key batch. Only non-empty strings are
      considered valid keys.
    de_config: Proto DynamicEmbeddingConfig that configures the embedding.
    embedding_dim: a positive int denoting the embedding dimension.
    sigma_dim: a non-negative int denoting the dimension of the sigma
      function. If sigma_dim = 0, no weighted composition is applied to the
      input; a simple average of the embeddings for each feature is
      returned.
    feature_name: Feature name for the embedding.
    em_steps: number of training steps in each iteration of optimizing sigma
      and the embedding alternately, which is done by stopping the gradients
      of the other. If em_steps <= 0, both sigma and the embeddings are
      optimized simultaneously. A proper em_steps can significantly reduce
      the generalization error.
    variables: a `BagOfSparseFeatureVariables` denoting the variables used for
      computing the CA-BSFE embedding. If None, creates new variables.
    service_address: The address of a knowledge bank service. If empty, the
      value passed from --kbs_address flag will be used instead.
    timeout_ms: Timeout in milliseconds for the connection. If negative, the
      connection never times out.

  Returns:
    A tuple of
    - embedding: A `Tensor` of shape [batch_size, embedding_dim]
      representing the composited embedding vector.
    - vc: A `Tensor` of shape [embedding_dim] representing the
      context vector.
    - sigma: A `Tensor` of shape [batch_size] representing the context-free
      probability.
    - input_embedding: A `Tensor` of shape [batch_size, embedding_dim]
      (2D) or [batch_size, max_sequence_length, embedding_dim] (3D)
      representing the input embedding.
    - variables: A list of tf.Variable defined in this function.
  Raises:
    TypeError: If de_config is not an instance of DynamicEmbeddingConfig.
    ValueError: If feature_name is not specified, or sigma_dim < 0 or
      embedding_dim <= 0.
  """
  if not isinstance(de_config, de_config_pb2.DynamicEmbeddingConfig):
    raise TypeError("de_config must be an instance of DynamicEmbeddingConfig.")
  if sigma_dim < 0:
    raise ValueError("Invalid sigma_dim: %d" % sigma_dim)
  if embedding_dim <= 0:
    raise ValueError("Invalid embedding_dim: %d" % embedding_dim)
  if not feature_name:
    raise ValueError("Must specify a valid feature_name.")
  # A single key batch is a [batch_size] input.
  if not isinstance(keys, tf.Tensor):
    keys = tf.convert_to_tensor(keys)
  is_single_key_batch = (len(keys.get_shape().as_list()) == 1)

  # Add to global collection of feature embeddings for export.
  bsfe_params = _BagOfSparseFeaturesEmbeddingParams(embedding_dim,
                                                    sigma_dim)
  with _lock:
    _feature_embedding_collections[feature_name] = bsfe_params

  # Case One: the simplest case when input is a batch of single keys like
  # ['a', 'b', 'c']. Just returns dynamic embedding lookup for each key.
  if sigma_dim == 0 and is_single_key_batch:
    embedding, _ = _partitioned_dynamic_embedding_lookup(
        keys,
        de_config,
        embedding_dim,
        sigma_dim,
        feature_name,
        service_address=service_address,
        timeout_ms=timeout_ms)
    return embedding, None, None, embedding, None

  # Define context vector and sigma function parameters.
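  # vc is the context vector shared by all keys of this feature, while
  # sigma_kernel and sigma_bias parameterize the sigma() function that
  # produces the per-key context-free probability used in the decomposition
  # below.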
  if sigma_dim > 0:
    if variables:
      vc = variables.context_free_vector
      sigma_kernel = variables.sigma_kernel
      sigma_bias = variables.sigma_bias
    else:
      vc = tf.Variable(
          tf.random.normal([embedding_dim]), name="%s_vc" % feature_name)
      sigma_kernel = tf.Variable(
          tf.random.normal([sigma_dim]),
          name="%s_sigma_kernal" % feature_name)
      sigma_bias = tf.Variable([0.0], name="%s_sigma_bias" % feature_name)

  input_embedding, sigma_emb = _partitioned_dynamic_embedding_lookup(
      keys,
      de_config,
      embedding_dim,
      sigma_dim,
      feature_name,
      service_address=service_address,
      timeout_ms=timeout_ms)

  # Allows sigma() and the embedding to be trained alternately (switching
  # every `em_steps` steps) rather than simultaneously.
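  # For example, with a hypothetical em_steps = 100: during steps 0-99 only
  # the embedding receives gradients (sigma is frozen), during steps 100-199
  # only sigma receives gradients, and so on.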
  global_step = tf.compat.v1.train.get_global_step()
  if global_step is not None and sigma_dim > 0 and em_steps > 0:
    should_update_embedding = tf.equal(tf.mod(global_step // em_steps, 2), 0)
    should_update_sigma = tf.equal(tf.mod(global_step // em_steps, 2), 1)

    # pylint: disable=g-long-lambda
    input_embedding = tf.cond(should_update_embedding, lambda: input_embedding,
                              lambda: tf.stop_gradient(input_embedding))
    sigma_emb = tf.cond(should_update_sigma, lambda: sigma_emb,
                        lambda: tf.stop_gradient(sigma_emb))
    # The code also works without the following two statements.
    sigma_kernel = tf.cond(should_update_sigma, lambda: sigma_kernel,
                           lambda: tf.stop_gradient(sigma_kernel))
    sigma_bias = tf.cond(should_update_sigma, lambda: sigma_bias,
                         lambda: tf.stop_gradient(sigma_bias))
    # pylint: enable=g-long-lambda

  # `variables` comes from either input or local definition.
  if sigma_dim > 0 and variables is None:
    variables = BagOfSparseFeatureVariables(vc, sigma_kernel, sigma_bias)

  # Case Two: input is a batch of single keys and sigma_dim is non-zero.
  # It reduces to computing the embedding decomposition for each input, i.e.,
  # [emb(x) = sigma(x) * vc + (1 - sigma(x)) * emb_i(x) for x in keys].
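  # The closer sigma(x) is to 1, the more the output relies on the shared
  # context vector vc; the closer it is to 0, the more it keeps the
  # key-specific embedding emb_i(x).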
  if is_single_key_batch:  # and sigma_dim != 0
    # shape [batch_size, 1]
    sigma = tf.matmul(sigma_emb, tf.expand_dims(sigma_kernel, [-1]))
    sigma = tf.sigmoid(tf.nn.bias_add(sigma, sigma_bias))
    # shape [batch_size, embedding_dim]
    embedding = tf.reshape(input_embedding, [-1, embedding_dim])
    embedding = sigma * vc + (1 - sigma) * embedding
    return embedding, vc, sigma, input_embedding, variables

  # Case Three: the rank of input keys > 1, e.g., [['a', 'b'], ['c', '']].
  # The bag of sparse features embedding for each example is computed.
  shape_list = _get_shape_as_list(keys)
  shape_list.append(embedding_dim)
  if sigma_dim > 0:
    sigma_emb = tf.reshape(sigma_emb, [-1, sigma_dim])
    sigma = tf.matmul(sigma_emb, tf.expand_dims(sigma_kernel, [-1]))
    sigma = tf.sigmoid(tf.nn.bias_add(sigma, sigma_bias))
    embedding = tf.reshape(input_embedding, [-1, embedding_dim])
    embedding = (sigma * vc + (1 - sigma) * embedding)
    embedding = tf.reshape(embedding, shape_list)
    sigma = tf.reshape(sigma, shape_list[:-1])
  else:
    embedding = input_embedding
    sigma = None
    vc = None

  # Only computes the average embeddings over the non-empty features.
  mask = tf.cast(tf.not_equal(keys, tf.zeros_like(keys)), dtype=tf.float32)
  mask = tf.reduce_sum(mask, -1)
  mask = tf.where(tf.equal(mask, tf.zeros_like(mask)), tf.ones_like(mask), mask)
  mask = 1 / mask
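  # For example, for keys [['a', 'b'], ['c', '']] the per-example non-empty
  # counts are [2, 1], so mask becomes [0.5, 1.0]; combined with the sum
  # below this yields a mean over the non-empty keys of each example
  # (assuming the lookup returns zero vectors for empty keys).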
  mask = tf.expand_dims(mask, -1)  # [batch_size, 1]
  tile_shape = _get_shape_as_list(keys)
  for i in range(len(tile_shape)):
    tile_shape[i] = 1
  tile_shape[-1] = embedding_dim
  mask = tf.tile(mask, tile_shape)
  embedding = tf.reduce_sum(embedding, -2)
  embedding *= mask
  return (embedding, vc, sigma, input_embedding,
          variables if sigma_dim > 0 else None)
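
Below is a minimal usage sketch, not part of the original file. It assumes the import paths shown (the actual CARLS package layout may differ), a reachable knowledge bank service, and a `DynamicEmbeddingConfig` populated elsewhere; the shapes follow the docstring above.

import tensorflow as tf

# Assumed import paths; the real module layout in CARLS may differ.
from research.carls import dynamic_embedding_config_pb2 as de_config_pb2
from research.carls.models.caml import sparse_features

# Hypothetical config; the concrete DynamicEmbeddingConfig fields are defined
# by the CARLS protos and are not shown in this listing.
de_config = de_config_pb2.DynamicEmbeddingConfig()

# A batch of two examples with up to two keys each; '' marks a missing key.
keys = tf.constant([["apple", "banana"], ["cherry", ""]])

embedding, vc, sigma, input_embedding, variables = (
    sparse_features.embed_single_feature(
        keys,
        de_config,
        embedding_dim=16,
        sigma_dim=4,
        feature_name="fruit",
        em_steps=100,
        service_address="localhost:6666"))  # assumed knowledge bank address

# embedding:       [2, 16]    composited per-example embedding.
# vc:              [16]       shared context vector.
# sigma:           [2, 2]     per-key context-free probabilities.
# input_embedding: [2, 2, 16] raw per-key embeddings from the lookup.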