def parse()

in tensorflow_ranking/python/data.py [0:0]


  def parse(self, serialized):
    """See `_RankingDataParser`.

    Decodes `serialized` into serialized context and per-example strings,
    optionally shuffles the valid examples, truncates or pads the list
    dimension to `self._list_size` when that is set, and parses everything
    into a single feature dictionary.

    Args:
      serialized: Input forwarded unchanged to
        `self._decode_as_serialized_example_list`.

    Returns:
      A dict mapping feature names to parsed tensors. Per-example features are
      reshaped so their leading dimensions are [batch_size, list_size];
      context features are merged in when a context feature spec is
      configured. The size and mask features are added under their configured
      names when those names are set.
    """
    decoded = self._decode_as_serialized_example_list(serialized)
    serialized_context, serialized_list, sizes = decoded

    # Prefer the statically known batch dimension; only fall back to the
    # dynamic shape op when the static value is unavailable.
    static_batch_size = serialized_context.get_shape().as_list()[0]
    batch_size = static_batch_size or tf.shape(input=serialized_list)[0]
    cur_list_size = tf.shape(input=serialized_list)[1]
    list_size = self._list_size

    if self._shuffle_examples:
      # Shuffle using indices derived from the mask of valid entries.
      valid_mask = tf.sequence_mask(sizes, cur_list_size)
      serialized_list = tf.gather_nd(
          serialized_list,
          utils.shuffle_valid_indices(valid_mask, seed=self._seed))

    # With a configured list size, force the list dimension to exactly that
    # length: cut off the tail when too long, pad with "" otherwise.
    if list_size:

      def _truncate_to_list_size():
        return tf.slice(serialized_list, [0, 0], [batch_size, list_size])

      def _pad_to_list_size():
        num_missing = list_size - cur_list_size
        return tf.pad(
            tensor=serialized_list,
            paddings=[[0, 0], [0, num_missing]],
            constant_values="")

      is_too_long = cur_list_size > list_size
      serialized_list = tf.cond(
          pred=is_too_long,
          true_fn=_truncate_to_list_size,
          false_fn=_pad_to_list_size)
      cur_list_size = list_size

    # Parse all examples as one flat batch, then restore the list dimension.
    flat_serialized = tf.reshape(serialized_list, [-1])
    parsed_examples = tf.compat.v1.io.parse_example(
        flat_serialized, self._example_feature_spec)

    features = {}
    for name, tensor in parsed_examples.items():
      if not isinstance(tensor, tf.RaggedTensor):
        # [batch_size * cur_list_size, ...] -> [batch_size, cur_list_size, ...]
        features[name] = utils.reshape_first_ndims(
            tensor, 1, [batch_size, cur_list_size])
        continue
      # RaggedTensor: regroup rows into uniform groups of cur_list_size, i.e.
      # [batch_size * cur_list_size, ...] -> [batch_size, cur_list_size, ...].
      features[name] = tf.RaggedTensor.from_uniform_row_length(
          tensor, cur_list_size)

    if self._context_feature_spec:
      flat_context = tf.reshape(serialized_context, [batch_size])
      context_features = tf.compat.v1.io.parse_example(
          flat_context, self._context_feature_spec)
      features.update(context_features)

    # Optionally expose the per-list sizes and the corresponding mask.
    if self._size_feature_name:
      features[self._size_feature_name] = sizes
    if self._mask_feature_name:
      features[self._mask_feature_name] = tf.sequence_mask(sizes, list_size)
    return features