# research/carls/models/caml/sparse_features.py
def embed_single_feature(keys: tf.Tensor,
de_config: de_config_pb2.DynamicEmbeddingConfig,
embedding_dim: int,
sigma_dim: int,
feature_name: typing.Text,
em_steps: int = 0,
variables: typing.Optional[BagOfSparseFeatureVariables] = None,
service_address: typing.Text = "",
timeout_ms: int = -1):
"""Embeds a single feature based on the embedding decomposition formula.
Args:
keys: A string `Tensor` of shape [batch_size] for a single-key batch or
[batch_size, None] for a multi-key batch. Only non-empty strings are
considered valid keys.
de_config: Proto DynamicEmbeddingConfig that configures the embedding.
embedding_dim: A positive int denoting the embedding dimension.
sigma_dim: A non-negative int denoting the dimension of the sigma
function. If sigma_dim = 0, no weighted composition is applied to the
input and a plain average of the embeddings of each feature is
returned.
feature_name: Feature name for the embedding.
em_steps: Number of training steps in each phase of alternately optimizing
sigma and the embeddings, implemented by stopping the gradients of
whichever is not being updated. If em_steps <= 0, both sigma and the
embeddings are optimized simultaneously. A proper em_steps can help
significantly reduce the generalization error.
variables: A `BagOfSparseFeatureVariables` denoting the variables used for
computing the CA-BSFE embedding. If None, creates new variables.
service_address: The address of a knowledge bank service. If empty, the
value passed from the --kbs_address flag is used instead.
timeout_ms: Timeout in milliseconds for the connection. If negative, never
times out.
Returns:
A tuple of
- embedding: A `Tensor` of shape [batch_size, embedding_dim]
representing the composited embedding vectors.
- vc: A `Tensor` of shape [embedding_dim] representing the
context vector.
- sigma: A `Tensor` of shape [batch_size] representing the context-free
probability.
- input_embedding: A `Tensor` of shape [batch_size, embedding_dim]
(2D) or [batch_size, max_sequence_length, embedding_dim] (3D)
representing the input embedding.
- variables: A list of tf.Variable defined in this function.
Raises:
TypeError: If de_config is not an instance of DynamicEmbeddingConfig.
ValueError: If feature_name is not specified, or sigma_dim < 0 or
embedding_dim <= 0.
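Example:
  A minimal usage sketch for a multi-key batch; construction of the
  `DynamicEmbeddingConfig` proto (here `config`) is elided and assumed
  to be valid, and the dimensions and feature name are illustrative:

    keys = tf.constant([["a", "b"], ["c", ""]])
    embedding, vc, sigma, input_embedding, variables = embed_single_feature(
        keys, config, embedding_dim=8, sigma_dim=4, feature_name="tokens")
    # embedding has shape [2, 8]: one composited vector per example,
    # averaged over each example's non-empty keys.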
"""
if not isinstance(de_config, de_config_pb2.DynamicEmbeddingConfig):
raise TypeError("de_config must be an instance of DynamicEmbeddingConfig.")
if sigma_dim < 0:
raise ValueError("Invalid sigma_dim: %d" % sigma_dim)
if embedding_dim <= 0:
raise ValueError("Invalid embedding_dim: %d" % embedding_dim)
if not feature_name:
raise ValueError("Must specify a valid feature_name.")
# A single key batch is a [batch_size] input.
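# E.g., ['a', 'b', 'c'] is a single-key batch (rank 1), while
# [['a', 'b'], ['c', '']] is a multi-key batch (rank 2).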
if not isinstance(keys, tf.Tensor):
keys = tf.convert_to_tensor(keys)
is_single_key_batch = (len(keys.get_shape().as_list()) == 1)
# Add to global collection of feature embeddings for export.
bsfe_params = _BagOfSparseFeaturesEmbeddingParams(embedding_dim,
sigma_dim)
with _lock:
_feature_embedding_collections[feature_name] = bsfe_params
# Case One: the simplest case when input is a batch of single keys like
# ['a', 'b', 'c']. Just returns dynamic embedding lookup for each key.
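# vc, sigma and variables are all None in this case since no weighted
# composition is applied.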
if sigma_dim == 0 and is_single_key_batch:
embedding, _ = _partitioned_dynamic_embedding_lookup(
keys,
de_config,
embedding_dim,
sigma_dim,
feature_name,
service_address=service_address,
timeout_ms=timeout_ms)
return embedding, None, None, embedding, None
# Define context vector and sigma function parameters.
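# vc is the per-feature context vector of shape [embedding_dim]; sigma_kernel
# and sigma_bias parameterize a single-unit dense layer over the sigma
# embedding that produces the context-free probability sigma(x).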
if sigma_dim > 0:
if variables:
vc = variables.context_free_vector
sigma_kernel = variables.sigma_kernel
sigma_bias = variables.sigma_bias
else:
vc = tf.Variable(
tf.random.normal([embedding_dim]), name="%s_vc" % feature_name)
sigma_kernel = tf.Variable(
tf.random.normal([sigma_dim]),
name="%s_sigma_kernal" % feature_name)
sigma_bias = tf.Variable([0.0], name="%s_sigma_bias" % feature_name)
input_embedding, sigma_emb = _partitioned_dynamic_embedding_lookup(
keys,
de_config,
embedding_dim,
sigma_dim,
feature_name,
service_address=service_address,
timeout_ms=timeout_ms)
# Allows sigma() and the embeddings to be trained alternately (switching
# every `em_steps` steps) rather than simultaneously.
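# E.g., with em_steps = 100, global steps [0, 100) update only the embeddings
# and steps [100, 200) update only sigma, and so on.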
global_step = tf.compat.v1.train.get_global_step()
if global_step is not None and sigma_dim > 0 and em_steps > 0:
em_phase = tf.math.floormod(global_step // em_steps, 2)
should_update_embedding = tf.equal(em_phase, 0)
should_update_sigma = tf.equal(em_phase, 1)
# pylint: disable=g-long-lambda
input_embedding = tf.cond(should_update_embedding, lambda: input_embedding,
lambda: tf.stop_gradient(input_embedding))
sigma_emb = tf.cond(should_update_sigma, lambda: sigma_emb,
lambda: tf.stop_gradient(sigma_emb))
# The following two statements are not strictly required, but they make the
# alternating schedule explicit for the sigma parameters as well.
sigma_kernel = tf.cond(should_update_sigma, lambda: sigma_kernel,
lambda: tf.stop_gradient(sigma_kernel))
sigma_bias = tf.cond(should_update_sigma, lambda: sigma_bias,
lambda: tf.stop_gradient(sigma_bias))
# pylint: enable=g-long-lambda
# `variables` comes from either input or local definition.
if sigma_dim > 0 and variables is None:
variables = BagOfSparseFeatureVariables(vc, sigma_kernel, sigma_bias)
# Case Two: input is a batch of single keys and sigma_dim > 0.
# It reduces to computing the embedding decomposition for each input, i.e.,
# [emb(x) = sigma(x) * vc + (1 - sigma(x)) * emb_i(x) for x in keys].
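# Here sigma(x) = sigmoid(dot(sigma_kernel, sigma_emb(x)) + sigma_bias) acts
# as the mixing weight between the context vector vc and emb_i(x).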
if is_single_key_batch: # and sigma_dim != 0
# sigma_emb: [batch_size, sigma_dim] -> sigma: [batch_size, 1].
sigma = tf.matmul(sigma_emb, tf.expand_dims(sigma_kernel, -1))
sigma = tf.sigmoid(tf.nn.bias_add(sigma, sigma_bias))
# shape [batch_size, embedding_dim]
embedding = tf.reshape(input_embedding, [-1, embedding_dim])
embedding = sigma * vc + (1 - sigma) * embedding
return embedding, vc, sigma, input_embedding, variables
# Case Three: the rank of input keys > 1, e.g., [['a', 'b'], ['c', '']].
# The bag of sparse features embedding for each example is computed.
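# When sigma_dim > 0, each key's embedding is first composed with vc as in
# Case Two; the per-key embeddings are then averaged over the non-empty keys
# of each example.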
shape_list = _get_shape_as_list(keys)
shape_list.append(embedding_dim)
if sigma_dim > 0:
sigma_emb = tf.reshape(sigma_emb, [-1, sigma_dim])
sigma = tf.matmul(sigma_emb, tf.expand_dims(sigma_kernel, -1))
sigma = tf.sigmoid(tf.nn.bias_add(sigma, sigma_bias))
embedding = tf.reshape(input_embedding, [-1, embedding_dim])
embedding = (sigma * vc + (1 - sigma) * embedding)
embedding = tf.reshape(embedding, shape_list)
sigma = tf.reshape(sigma, shape_list[:-1])
else:
embedding = input_embedding
sigma = None
vc = None
# Only computes the average embeddings over the non-empty features.
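# E.g., for keys [['a', 'b'], ['c', '']] the non-empty counts are [2, 1], so
# the summed embeddings are scaled by [1/2, 1]. A count of zero is replaced
# by one to avoid division by zero.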
mask = tf.cast(tf.not_equal(keys, tf.zeros_like(keys)), dtype=tf.float32)
mask = tf.reduce_sum(mask, -1)
mask = tf.where(tf.equal(mask, tf.zeros_like(mask)), tf.ones_like(mask), mask)
mask = 1 / mask
mask = tf.expand_dims(mask, -1) # [batch_size, 1]
tile_shape = _get_shape_as_list(keys)
for i in range(len(tile_shape)):
tile_shape[i] = 1
tile_shape[-1] = embedding_dim
mask = tf.tile(mask, tile_shape)
embedding = tf.reduce_sum(embedding, -2)
embedding *= mask
return (embedding, vc, sigma, input_embedding,
variables if sigma_dim > 0 else None)