easy_rec/python/compat/embedding_ops.py (91 lines of code) (raw):
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Add embedding column for EmbeddingVariable which is only available on pai."""
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
def _prune_invalid_ids(sparse_ids, sparse_weights):
  """Drop entries whose id is negative from the ids and, if given, weights.

  Args:
    sparse_ids: sparse ids; every entry whose value is below 0 is removed.
    sparse_weights: sparse weights filtered with the same mask as the ids,
      or `None` when there are no weights.

  Returns:
    A `(sparse_ids, sparse_weights)` tuple keeping only entries with id >= 0.
    The second element is `None` when `sparse_weights` was `None`.
  """
  has_weights = sparse_weights is not None
  keep_mask = math_ops.greater_equal(sparse_ids.values, 0)
  if has_weights:
    # AND-ing with an all-True mask cannot switch any entry off.
    # NOTE(review): presumably kept so that ids/weights with a different
    # number of values fail here -- confirm before removing.
    all_true = array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool)
    keep_mask = math_ops.logical_and(keep_mask, all_true)
  pruned_ids = sparse_ops.sparse_retain(sparse_ids, keep_mask)
  if not has_weights:
    return pruned_ids, None
  pruned_weights = sparse_ops.sparse_retain(sparse_weights, keep_mask)
  return pruned_ids, pruned_weights
def _prune_invalid_weights(sparse_ids, sparse_weights):
"""Prune invalid weights (< 0) from the input ids and weights."""
if sparse_weights is not None:
is_weights_valid = math_ops.greater(sparse_weights.values, 0)
sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid)
sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid)
return sparse_ids, sparse_weights
def safe_embedding_lookup_sparse(embedding_weights,
                                 sparse_ids,
                                 sparse_weights=None,
                                 combiner='mean',
                                 default_id=None,
                                 name=None,
                                 partition_strategy='div',
                                 max_norm=None):
  """Look up and combine embeddings, tolerating invalid ids and empty rows.

  Adapted so that it can be used with Pai EmbeddingVariables.

  The partitioned embedding in `embedding_weights` must all be the same shape
  except for the first dimension. The first dimension is allowed to vary as
  the vocabulary size is not necessarily a multiple of `P`.
  `embedding_weights` may be a `PartitionedVariable` as returned by using
  `tf.get_variable()` with a partitioner.

  Entries with a negative id are always pruned. Entries with a non-positive
  weight are pruned as well, except when `combiner` is 'sum', in which case
  the weight-based pruning is skipped. For an entry with no features, the
  embedding vector for `default_id` is returned, or the 0-vector if
  `default_id` is not supplied.

  The ids and weights may be multi-dimensional. Embeddings are always
  aggregated along the last dimension.

  Args:
    embedding_weights: A list of `P` float `Tensor`s or values representing
      partitioned embedding `Tensor`s. Alternatively, a `PartitionedVariable`
      created by partitioning along dimension 0. The total unpartitioned
      shape should be `[e_0, e_1, ..., e_m]`, where `e_0` represents the
      vocab size and `e_1, ..., e_m` are the embedding dimensions.
    sparse_ids: `SparseTensor` of shape `[d_0, d_1, ..., d_n]` containing the
      ids. `d_0` is typically batch size.
    sparse_weights: `SparseTensor` of same shape as `sparse_ids`, containing
      float weights corresponding to `sparse_ids`, or `None` if all weights
      are to be assumed to be 1.0.
    combiner: A string specifying how to combine embedding results for each
      entry. Currently "mean", "sqrtn" and "sum" are supported, with "mean"
      the default.
    default_id: The id to use for an entry with no features.
    name: A name for this operation (optional).
    partition_strategy: A string specifying the partitioning strategy.
      Currently `"div"` and `"mod"` are supported. Default is `"div"`.
    max_norm: If not `None`, all embeddings are l2-normalized to max_norm
      before combining.

  Returns:
    Dense `Tensor` of shape `[d_0, d_1, ..., d_{n-1}, e_1, ..., e_m]`.

  Raises:
    ValueError: if `embedding_weights` is `None`.
  """
  if embedding_weights is None:
    raise ValueError('Missing embedding_weights %s.' % embedding_weights)

  scope_inputs = [ops.convert_to_tensor(embedding_weights)]
  scope_inputs = scope_inputs + [sparse_ids, sparse_weights]

  with ops.name_scope(name, 'embedding_lookup', scope_inputs) as scope:
    # Flatten the ids to 2-D: all leading dimensions collapse into one row
    # axis, the last dimension is kept as the per-row feature axis.
    input_shape = sparse_ids.dense_shape
    rank_dim = input_shape.get_shape()[0]
    static_rank = rank_dim.value
    if static_rank is None:
      # Rank is not known statically; fall back to a runtime value.
      input_rank = array_ops.size(input_shape)
    else:
      input_rank = static_rank
    num_leading = input_rank - 1

    flat_rows = math_ops.reduce_prod(
        array_ops.slice(input_shape, [0], [num_leading]))
    row_width = array_ops.gather(input_shape, num_leading)
    sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [flat_rows, row_width])
    if sparse_weights is not None:
      # Reuse the reshaped ids' indices and shape for the weights.
      sparse_weights = sparse_tensor.SparseTensor(
          indices=sparse_ids.indices,
          values=sparse_weights.values,
          dense_shape=sparse_ids.dense_shape)

    # Remove negative ids, and (unless summing) non-positive weights.
    sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
    if combiner != 'sum':
      sparse_ids, sparse_weights = _prune_invalid_weights(
          sparse_ids, sparse_weights)

    # Give every empty row one placeholder id (and weight 1.0) so the lookup
    # yields one output row per input row.
    fill_id = default_id or 0
    sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(
        sparse_ids, fill_id)
    if sparse_weights is not None:
      sparse_weights, _ = sparse_ops.sparse_fill_empty_rows(sparse_weights, 1.0)

    # Ids are cast to int64 before the lookup when they are not already.
    id_indices = sparse_ids.indices
    id_values = sparse_ids.values
    if id_values.dtype != dtypes.int64:
      id_values = math_ops.to_int64(id_values)
    sparse_ids = sparse_tensor.SparseTensor(
        indices=id_indices,
        values=id_values,
        dense_shape=sparse_ids.dense_shape)

    # The op that produces the final 2-D result carries the scope name: the
    # lookup itself when a default id exists, otherwise the masking below.
    lookup_name = scope if default_id is not None else None
    result = embedding_ops.embedding_lookup_sparse(
        embedding_weights,
        sparse_ids,
        sparse_weights,
        combiner=combiner,
        partition_strategy=partition_strategy,
        name=lookup_name,
        max_norm=max_norm)

    if default_id is None:
      # No default id: rows that were empty are replaced by zeros. The row
      # mask is tiled across the second dimension to match `result`.
      column_mask = array_ops.reshape(is_row_empty, [-1, 1])
      tile_multiples = array_ops.stack([1, array_ops.shape(result)[1]])
      empty_mask = array_ops.tile(column_mask, tile_multiples)
      result = array_ops.where(
          empty_mask, array_ops.zeros_like(result), result, name=scope)

    # Restore the original leading dimensions in front of the embedding dims.
    leading_dims = array_ops.slice(
        math_ops.cast(input_shape, dtypes.int32), [0], [num_leading])
    embed_dims = array_ops.slice(array_ops.shape(result), [1], [-1])
    final_result = array_ops.reshape(
        result, array_ops.concat([leading_dims, embed_dims], 0))

    # Static shape: unknown leading dims followed by the known embedding dims.
    static_leading = tensor_shape.unknown_shape((rank_dim - 1).value)
    final_result.set_shape(
        static_leading.concatenate(result.get_shape()[1:]))
    return final_result