in tensorflow_ranking/python/data.py [0:0]
def parse(self, serialized):
  """See `_RankingDataParser`.

  Parses serialized `SequenceExample` protos and normalizes every per-example
  (feature_list) feature along its 2nd dimension to a common list size, by
  truncating or padding.

  Args:
    serialized: The serialized protos; forwarded unchanged to
      `tf.io.parse_sequence_example`.

  Returns:
    A dict mapping feature names to tensors. It contains the parsed context
    features as-is and the example features with shape
    [batch_size, list_size, ...]. If `self._size_feature_name` is set, the
    per-row number of valid examples is stored under that name; if
    `self._mask_feature_name` is set, a `tf.sequence_mask` built from those
    sizes is stored under that name.

  Raises:
    ValueError: If `self._shuffle_examples` is set, since shuffling is not
      supported for the SequenceExample format.
  """
  if self._shuffle_examples:
    raise ValueError(
        "Shuffling examples is not supported in SequenceExample format.")

  list_size = self._list_size
  context_feature_spec = self._context_feature_spec
  example_feature_spec = self._example_feature_spec

  # Convert `FixedLenFeature` in `example_feature_spec` to
  # `FixedLenSequenceFeature` to parse the `feature_lists` in SequenceExample.
  # In addition, we collect non-trivial `default_value`s (neither "" nor 0)
  # for post-processing. This is because no `default_value` except None is
  # allowed for `FixedLenSequenceFeature`. Also, we set allow_missing=True and
  # handle the missing feature_list later.
  fixed_len_sequence_features = {}
  padding_values = {}
  non_trivial_padding_values = {}
  for k, s in six.iteritems(example_feature_spec):
    if not isinstance(s, tf.io.FixedLenFeature):
      continue
    fixed_len_sequence_features[k] = tf.io.FixedLenSequenceFeature(
        s.shape, s.dtype, allow_missing=True)
    scalar = _get_scalar_default_value(s.dtype, s.default_value)
    padding_values[k] = scalar
    if scalar:
      non_trivial_padding_values[k] = scalar

  # Work on a copy so that the caller-provided spec is not mutated.
  sequence_features = example_feature_spec.copy()
  sequence_features.update(fixed_len_sequence_features)
  context, examples, sizes = tf.io.parse_sequence_example(
      serialized,
      context_features=context_feature_spec,
      sequence_features=sequence_features)

  # Infer sizes from ragged and sparse tensors if there are no dense features
  # to infer it from.
  if not sizes:
    for name, example in examples.items():
      if isinstance(example, tf.RaggedTensor):
        # `row_lengths` is a method: it must be called to obtain the tensor
        # of per-row lengths. Storing the bound method itself would break the
        # `tf.stack` over `sizes.values()` below.
        sizes[name] = example.row_lengths()
      elif isinstance(example, tf.sparse.SparseTensor):
        # Largest index along dim 1 per batch row, plus one, is the number of
        # frames in that row.
        sizes[name] = 1 + tf.math.segment_max(example.indices[:, 1],
                                              example.indices[:, 0])

  # Reset to no trivial padding values for example features.
  for k, v in six.iteritems(non_trivial_padding_values):
    tensor = examples[k]  # [batch_size, num_frames, feature_size]
    tensor.get_shape().assert_has_rank(3)
    size = tf.reshape(sizes[k], [-1, 1, 1])  # [batch_size, 1, 1]
    # Frame index of every element. The tiled range has
    # batch_size * num_frames elements, so reshaping it to the shape of
    # `tensor` only succeeds when feature_size is 1.
    rank = tf.reshape(
        tf.tile(
            tf.range(tf.shape(input=tensor)[1]), [tf.shape(input=tensor)[0]]),
        tf.shape(input=tensor))
    # Keep parsed values for valid frames; use the default value elsewhere.
    tensor = tf.compat.v1.where(
        tf.less(rank, tf.cast(size, tf.int32)), tensor,
        tf.fill(tf.shape(input=tensor), tf.cast(v, tensor.dtype)))
    examples[k] = tensor

  # Remember the configured (possibly None) list size: it is used for the
  # static shape, whereas `list_size` may become a dynamic tensor below.
  list_size_arg = list_size
  if list_size is None:
    # Use dynamic list_size. This is needed to pad missing feature_list.
    list_size_dynamic = tf.reduce_max(
        input_tensor=tf.stack(
            [_bounding_shape(t)[1] for t in six.itervalues(examples)]))
    list_size = list_size_dynamic

  # Collect features. Truncate or pad example features to normalize the tensor
  # shape: [batch_size, num_frames, ...] --> [batch_size, list_size, ...]
  features = {}
  features.update(context)
  for k, t in six.iteritems(examples):
    # Old shape: [batch_size, num_frames, ...]
    shape = _bounding_shape(t)
    ndims = shape.shape[0]
    num_frames = shape[1]
    # New shape: [batch_size, list_size, ...]
    new_shape = tf.concat([[shape[0], list_size], shape[2:]], 0)

    # Loop variables are bound as default arguments so that each closure
    # captures the values of the current iteration.
    def truncate_fn(t=t, ndims=ndims, new_shape=new_shape):
      """Truncates the tensor."""
      if isinstance(t, tf.sparse.SparseTensor):
        return tf.sparse.slice(t, [0] * ndims,
                               tf.cast(new_shape, dtype=tf.int64))
      elif isinstance(t, tf.RaggedTensor):
        return t[:, :new_shape[1], ...]
      else:
        return tf.slice(t, [0] * ndims, new_shape)

    def pad_fn(k=k,
               t=t,
               ndims=ndims,
               num_frames=num_frames,
               new_shape=new_shape):
      """Pads the tensor."""
      if isinstance(t, tf.sparse.SparseTensor):
        return tf.sparse.reset_shape(t, new_shape)
      elif isinstance(t, tf.RaggedTensor):
        # Convert ragged to a flattened sparse tensor with shape
        # [batch_size * list_size, ...].
        sparse_tensor = tf.sparse.reset_shape(t.to_sparse(), new_shape)
        flattened_shape = tf.concat(
            [[new_shape[0] * list_size], new_shape[2:]], axis=0)
        flattened_sparse_tensor = tf.sparse.reshape(sparse_tensor,
                                                    flattened_shape)
        # Convert to a ragged tensor of shape [batch_size * list_size, ...].
        ragged_tensor = tf.RaggedTensor.from_sparse(
            flattened_sparse_tensor, row_splits_dtype=t.row_splits.dtype)
        # Reshape ragged tensor to [batch_size, list_size, ...]
        return tf.RaggedTensor.from_uniform_row_length(
            ragged_tensor, list_size)
      else:
        # Paddings has shape [n, 2] where n is the rank of the tensor.
        paddings = tf.stack([[0, 0], [0, list_size - num_frames]] + [[0, 0]] *
                            (ndims - 2))
        return tf.pad(
            tensor=t, paddings=paddings, constant_values=padding_values[k])

    tensor = tf.cond(
        pred=num_frames > list_size, true_fn=truncate_fn, false_fn=pad_fn)
    # Infer static shape for Tensor. Set the 2nd dim to None and set_shape
    # merges `static_shape` with the existing static shape of the tensor.
    if not isinstance(tensor, (tf.sparse.SparseTensor, tf.RaggedTensor)):
      static_shape = t.get_shape().as_list()
      static_shape[1] = list_size_arg
      tensor.set_shape(static_shape)
    features[k] = tensor

  # The size of each row is the largest size reported by any example feature.
  example_sizes = tf.stack(list(sizes.values()), axis=1)
  sizes = tf.reduce_max(input_tensor=example_sizes, axis=1)
  if self._size_feature_name:
    features[self._size_feature_name] = sizes
  if self._mask_feature_name:
    features[self._mask_feature_name] = tf.sequence_mask(sizes, list_size)
  return features