def get_io_spec()

in tensorflow_gnn/graph/graph_tensor_io.py [0:0]


def get_io_spec(spec: gt.GraphTensorSpec,
                prefix: Optional[str] = None,
                validate: bool = False) -> Dict[str, IOFeature]:
  """Builds the `tf.io` parsing configuration for a `GraphTensorSpec`.

  Each flattened field of `spec` is mapped to a feature configuration object
  (`tf.io.FixedLenFeature` or `tf.io.RaggedFeature`), keyed by the name of the
  `tf.train.Feature` that stores it. The result can be handed to
  `tf.io.parse_example()` to read the individual fields of a `GraphTensor`
  from `tf.train.Example` protos (see
  https://www.tensorflow.org/api_docs/python/tf/io), and so effectively defines
  how a `GraphTensor` is encoded as a `tf.train.Example`.

  (This is an internal function. For decoding graph tensors you most likely
  want `gnn.parse_example()`, which takes care of this step for you.)

  Args:
    spec: A graph tensor type specification.
    prefix: An optional string prepended to every feature name. Useful when
      other data is encoded in the same protocol buffer.
    validate: A boolean indicating whether or not to validate that the input
      fields form a valid GraphTensor. It is forwarded as the `validate`
      argument of every `tf.io.RaggedFeature` created here. Defaults to False.

  Returns:
    A dict from `tf.train.Feature` name to feature configuration object, to be
    used in `tf.io.parse_example()`.

  Raises:
    ValueError: If a field is a dense tensor with an unknown non-leading
      dimension, or if its type spec is neither a `tf.RaggedTensorSpec` nor a
      `tf.TensorSpec`.
  """

  def _partitions_for(
      fname: str, shape: tf.TensorShape) -> Tuple[RaggedPartition, ...]:
    """Returns one partition per non-leading dimension of `shape`."""
    # Unknown dimensions read their row lengths from a companion feature named
    # `<fname>.d<axis>`; known dimensions become uniform row lengths.
    # pytype: disable=attribute-error
    return tuple(
        tf.io.RaggedFeature.RowLengths(f'{fname}.d{axis}')
        if size is None else tf.io.RaggedFeature.UniformRowLength(size)
        for axis, size in enumerate(shape.as_list()[1:], start=1))
    # pytype: enable=attribute-error

  def _feature_for(fname: str, value_spec: gt.FieldSpec) -> IOFeature:
    """Returns the parsing configuration for a single field."""
    io_dtype = _get_io_type(value_spec.dtype)

    if isinstance(value_spec, tf.RaggedTensorSpec):
      return tf.io.RaggedFeature(
          dtype=io_dtype,
          value_key=fname,
          partitions=_partitions_for(fname, value_spec.shape),
          row_splits_dtype=value_spec.row_splits_dtype,
          validate=validate)

    if not isinstance(value_spec, tf.TensorSpec):
      raise ValueError(
          f'Unsupported type spec {type(value_spec).__name__}, fname={fname}')

    dense_shape = value_spec.shape
    if None not in dense_shape.as_list():
      # Shape [d0..dk] with every di static (compile-time constant): the value
      # is parsed as a dense tensor. Only size fields get a default (zeros).
      if _is_size_field(fname):
        default_value = tf.zeros(dense_shape, io_dtype)
      else:
        default_value = None
      return tf.io.FixedLenFeature(
          shape=dense_shape, dtype=io_dtype, default_value=default_value)

    if not dense_shape[1:].is_fully_defined():
      raise ValueError(
          ('Expected dense tensor with static non-leading dimensions'
           f', got shape={value_spec.shape}, fname={fname}'))

    # Shape [None, d1..dk] with every di static (compile-time constant): the
    # value is parsed as a ragged tensor with ragged rank 0. Parsing a single
    # example yields a dense tensor, parsing multiple examples a ragged one.
    uniform_partitions = _partitions_for(fname, dense_shape)
    # pytype: disable=attribute-error
    assert all(
        isinstance(part, tf.io.RaggedFeature.UniformRowLength)
        for part in uniform_partitions)
    # pytype: enable=attribute-error
    return tf.io.RaggedFeature(
        dtype=io_dtype,
        value_key=fname,
        partitions=uniform_partitions,
        row_splits_dtype=spec.indices_dtype,
        validate=validate)

  io_features = {}
  for field_name, field_spec in _flatten_graph_field_specs(spec, '').items():
    feature_key = f'{prefix}{field_name}' if prefix else field_name
    # pylint: disable=protected-access
    boxed_spec = gp._box_spec(spec.rank, field_spec, spec.indices_dtype)
    io_features[feature_key] = _feature_for(feature_key, boxed_spec)

  return io_features