def parse()

in tensorflow_ranking/python/data.py [0:0]


  def parse(self, serialized):
    """Parses a batch of serialized SequenceExample protos into features.

    See `_RankingDataParser` for the general contract. Context features are
    returned as produced by `tf.io.parse_sequence_example`. Each example
    (sequence) feature is truncated or padded along its 2nd dimension so that
    its shape goes from [batch_size, num_frames, ...] to
    [batch_size, list_size, ...].

    Args:
      serialized: Serialized `SequenceExample` protos, passed directly to
        `tf.io.parse_sequence_example`.

    Returns:
      A dict mapping feature names to `Tensor`, `SparseTensor` or
      `RaggedTensor` values. If `self._size_feature_name` is set, that key maps
      to the per-list maximum, across example features, of the number of
      parsed frames. If `self._mask_feature_name` is set, that key maps to a
      boolean mask built with `tf.sequence_mask(sizes, list_size)`.

    Raises:
      ValueError: If example shuffling was requested; it is not supported for
        the SequenceExample format.
    """
    if self._shuffle_examples:
      raise ValueError(
          "Shuffling examples is not supported in SequenceExample format.")

    list_size = self._list_size
    context_feature_spec = self._context_feature_spec
    example_feature_spec = self._example_feature_spec
    # Convert `FixedLenFeature` in `example_feature_spec` to
    # `FixedLenSequenceFeature` to parse the `feature_lists` in SequenceExample.
    # In addition, we collect non-trivial `default_value`s (neither "" nor 0)
    # for post-processing. This is because no `default_value` except None is
    # allowed for `FixedLenSequenceFeature`. Also, we set allow_missing=True and
    # handle the missing feature_list later.
    fixed_len_sequence_features = {}
    padding_values = {}
    non_trivial_padding_values = {}
    for k, s in six.iteritems(example_feature_spec):
      if not isinstance(s, tf.io.FixedLenFeature):
        continue
      fixed_len_sequence_features[k] = tf.io.FixedLenSequenceFeature(
          s.shape, s.dtype, allow_missing=True)
      scalar = _get_scalar_default_value(s.dtype, s.default_value)
      padding_values[k] = scalar
      if scalar:
        non_trivial_padding_values[k] = scalar

    sequence_features = example_feature_spec.copy()
    sequence_features.update(fixed_len_sequence_features)
    context, examples, sizes = tf.io.parse_sequence_example(
        serialized,
        context_features=context_feature_spec,
        sequence_features=sequence_features)

    # Infer sizes from ragged and sparse tensors if there are no dense features
    # to infer it from. Inferred sizes are cast to int64 so that they can be
    # stacked together further below regardless of each feature's index dtype.
    if not sizes:
      for name, example in examples.items():
        if isinstance(example, tf.RaggedTensor):
          # `row_lengths` is a method; it must be called to obtain the tensor.
          sizes[name] = tf.cast(example.row_lengths(), tf.int64)
        elif isinstance(example, tf.sparse.SparseTensor):
          # A list's size is 1 + the largest frame index holding a value.
          # `unsorted_segment_max` with an explicit `num_segments` returns one
          # entry per batch element (including trailing lists with no values);
          # empty segments yield the dtype's lowest value, clamped here to 0.
          batch_ids = example.indices[:, 0]
          frame_ids = example.indices[:, 1]
          sizes[name] = tf.maximum(
              tf.math.unsorted_segment_max(
                  frame_ids + 1,
                  batch_ids,
                  num_segments=example.dense_shape[0]), 0)

    # Reset to no trivial padding values for example features.
    for k, v in six.iteritems(non_trivial_padding_values):
      tensor = examples[k]  # [batch_size, num_frames, ...]
      tensor_shape = tf.shape(input=tensor)
      # valid[b, f] is True for frames f < sizes[k][b], i.e. frames that were
      # parsed rather than filled in by `allow_missing` padding.
      valid = tf.sequence_mask(sizes[k], maxlen=tensor_shape[1])
      # Expand over the trailing feature dimensions (if any) so the mask works
      # for scalar features (rank 2) as well as multi-valued ones (rank >= 3).
      for _ in range(tensor.get_shape().ndims - 2):
        valid = tf.expand_dims(valid, axis=-1)
      valid = tf.broadcast_to(valid, tensor_shape)
      tensor = tf.compat.v1.where(
          valid, tensor, tf.fill(tensor_shape, tf.cast(v, tensor.dtype)))
      examples[k] = tensor

    list_size_arg = list_size
    if list_size is None:
      # Use dynamic list_size. This is needed to pad missing feature_list.
      list_size_dynamic = tf.reduce_max(
          input_tensor=tf.stack(
              [_bounding_shape(t)[1] for t in six.itervalues(examples)]))
      list_size = list_size_dynamic

    # Collect features. Truncate or pad example features to normalize the tensor
    # shape: [batch_size, num_frames, ...] --> [batch_size, list_size, ...]
    features = {}
    features.update(context)
    for k, t in six.iteritems(examples):
      # Old shape: [batch_size, num_frames, ...]
      shape = _bounding_shape(t)
      ndims = shape.shape[0]
      num_frames = shape[1]
      # New shape: [batch_size, list_size, ...]
      new_shape = tf.concat([[shape[0], list_size], shape[2:]], 0)

      # Loop variables are bound as defaults so each closure captures the
      # values of the current iteration rather than the last one.
      def truncate_fn(t=t, ndims=ndims, new_shape=new_shape):
        """Truncates the tensor."""
        if isinstance(t, tf.sparse.SparseTensor):
          return tf.sparse.slice(t, [0] * ndims,
                                 tf.cast(new_shape, dtype=tf.int64))
        elif isinstance(t, tf.RaggedTensor):
          return t[:, :new_shape[1], ...]
        else:
          return tf.slice(t, [0] * ndims, new_shape)

      def pad_fn(k=k,
                 t=t,
                 ndims=ndims,
                 num_frames=num_frames,
                 new_shape=new_shape):
        """Pads the tensor."""
        if isinstance(t, tf.sparse.SparseTensor):
          return tf.sparse.reset_shape(t, new_shape)
        elif isinstance(t, tf.RaggedTensor):
          # Convert ragged to a flattened sparse tensor with shape
          # [batch_size * list_size, ...].
          sparse_tensor = tf.sparse.reset_shape(t.to_sparse(), new_shape)
          flattened_shape = tf.concat(
              [[new_shape[0] * list_size], new_shape[2:]], axis=0)
          flattened_sparse_tensor = tf.sparse.reshape(sparse_tensor,
                                                      flattened_shape)
          # Convert to a ragged tensor of shape [batch_size * list_size, ...].
          ragged_tensor = tf.RaggedTensor.from_sparse(
              flattened_sparse_tensor, row_splits_dtype=t.row_splits.dtype)
          # Reshape ragged tensor to [batch_size, list_size, ...]
          return tf.RaggedTensor.from_uniform_row_length(
              ragged_tensor, list_size)
        else:
          # Paddings has shape [n, 2] where n is the rank of the tensor.
          paddings = tf.stack([[0, 0], [0, list_size - num_frames]] + [[0, 0]] *
                              (ndims - 2))
          return tf.pad(
              tensor=t, paddings=paddings, constant_values=padding_values[k])

      tensor = tf.cond(
          pred=num_frames > list_size, true_fn=truncate_fn, false_fn=pad_fn)
      # Infer static shape for Tensor. Set the 2nd dim to None and set_shape
      # merges `static_shape` with the existing static shape of the tensor.
      if not isinstance(tensor, (tf.sparse.SparseTensor, tf.RaggedTensor)):
        static_shape = t.get_shape().as_list()
        static_shape[1] = list_size_arg
        tensor.set_shape(static_shape)
      features[k] = tensor

    # Per-list size: the maximum number of frames across all example features.
    example_sizes = tf.stack(list(sizes.values()), axis=1)
    sizes = tf.reduce_max(input_tensor=example_sizes, axis=1)

    if self._size_feature_name:
      features[self._size_feature_name] = sizes
    if self._mask_feature_name:
      features[self._mask_feature_name] = tf.sequence_mask(sizes, list_size)
    return features