in tensorflow_ranking/python/data.py [0:0]
def parse(self, serialized):
  """Parses serialized example lists into a dict of feature tensors.

  See `_RankingDataParser`.

  Args:
    serialized: Serialized input, passed unchanged to
      `self._decode_as_serialized_example_list`.

  Returns:
    A dict mapping feature names to parsed tensors. Per-example features are
    reshaped from `[batch_size * list_size, ...]` to
    `[batch_size, list_size, ...]`. Context features are added when
    `self._context_feature_spec` is set, and the size and mask features are
    added when their feature names are configured.
  """
  (serialized_context, serialized_list,
   sizes) = self._decode_as_serialized_example_list(serialized)

  # Use static batch size whenever possible.
  batch_size = serialized_context.get_shape().as_list()[0] or tf.shape(
      input=serialized_list)[0]
  cur_list_size = tf.shape(input=serialized_list)[1]
  list_size = self._list_size

  if self._shuffle_examples:
    # Build a validity mask from `sizes` and let the helper derive the gather
    # indices. Shuffling runs before truncation, so truncation below acts on
    # the already shuffled list.
    is_valid = tf.sequence_mask(sizes, cur_list_size)
    indices = utils.shuffle_valid_indices(is_valid, seed=self._seed)
    serialized_list = tf.gather_nd(serialized_list, indices)

  # Apply truncation or padding to align tensor shape.
  if list_size:

    def truncate_fn():
      return tf.slice(serialized_list, [0, 0], [batch_size, list_size])

    def pad_fn():
      # Pad the list dimension with empty strings up to `list_size`. When the
      # two sizes are already equal this pads by zero.
      return tf.pad(
          tensor=serialized_list,
          paddings=[[0, 0], [0, list_size - cur_list_size]],
          constant_values="")

    serialized_list = tf.cond(
        pred=cur_list_size > list_size, true_fn=truncate_fn, false_fn=pad_fn)
    # From here on the list dimension is exactly `list_size`.
    cur_list_size = list_size

  features = {}
  example_features = tf.compat.v1.io.parse_example(
      tf.reshape(serialized_list, [-1]), self._example_feature_spec)
  for name, tensor in example_features.items():
    if isinstance(tensor, tf.RaggedTensor):
      # Reshape from [batch_size * cur_list_size, ...] to
      # [batch_size, cur_list_size, ...] for RaggedTensor.
      features[name] = tf.RaggedTensor.from_uniform_row_length(
          tensor, cur_list_size)
    else:
      features[name] = utils.reshape_first_ndims(
          tensor, 1, [batch_size, cur_list_size])

  if self._context_feature_spec:
    features.update(
        tf.compat.v1.io.parse_example(
            tf.reshape(serialized_context, [batch_size]),
            self._context_feature_spec))

  # Add example list sizes to features, if needed. `sizes` is stored exactly
  # as returned by the decoder; it is not clipped to `list_size` here.
  if self._size_feature_name:
    features[self._size_feature_name] = sizes
  if self._mask_feature_name:
    # Use `cur_list_size` rather than `self._list_size` so the mask always has
    # the same list dimension as the per-example features above. With an unset
    # `list_size`, `tf.sequence_mask` would fall back to max(sizes) (or a
    # zero-width mask for 0), which is not guaranteed to match.
    features[self._mask_feature_name] = tf.sequence_mask(
        sizes, cur_list_size)
  return features