in tensorflow_gnn/graph/graph_tensor_io.py [0:0]
def get_io_spec(spec: gt.GraphTensorSpec,
                prefix: Optional[str] = None,
                validate: bool = False) -> Dict[str, IOFeature]:
  """Returns tf.io parsing features for GraphTensorSpec.

  This function returns a mapping of `tf.train.Feature` names to configuration
  objects that can be used to parse instances of `tf.train.Example` (see
  https://www.tensorflow.org/api_docs/python/tf/io). The resulting mapping can
  be used with `tf.io.parse_example()` for reading the individual fields of a
  `GraphTensor` instance. This essentially forms our encoding of a `GraphTensor`
  to a `tf.train.Example` proto.

  (This is an internal function. You are not likely to be using this if you're
  decoding graph tensors. Instead, you should use the `gnn.parse_example()`
  routine directly, which handles this process for you.)

  Each flattened field of `spec` is mapped as follows:
    * ragged field specs become a `tf.io.RaggedFeature` with one partition per
      non-leading dimension;
    * dense field specs with a fully static shape become a
      `tf.io.FixedLenFeature`;
    * dense field specs whose only unknown dimension is the leading one become
      a `tf.io.RaggedFeature` with uniform row length partitions.

  Args:
    spec: A graph tensor type specification.
    prefix: An optional prefix string over all the features. You may use
      this if you are encoding other data in the same protocol buffer.
    validate: A boolean indicating whether or not to validate that the input
      fields form a valid GraphTensor. It is forwarded as the `validate`
      argument of every `tf.io.RaggedFeature` created here; fixed-length
      features are not affected by it. Defaults to False.

  Returns:
    A dict of `tf.train.Feature` name to feature configuration object, to be
    used in `tf.io.parse_example()`.

  Raises:
    ValueError: If a field spec is neither a `tf.RaggedTensorSpec` nor a
      `tf.TensorSpec`, or if a dense field spec has an unknown dimension other
      than the leading one.
  """

  def get_io_ragged_partitions(
      fname: str, shape: tf.TensorShape) -> Tuple[RaggedPartition, ...]:
    """Builds one ragged partition per non-leading dimension of `shape`."""
    partitions = []
    # The leading dimension is skipped; numbering starts at 1 so that the
    # row lengths feature of dimension `i` is named '<fname>.d<i>'.
    for i, dim in enumerate(shape.as_list()[1:], start=1):
      # pytype: disable=attribute-error
      if dim is None:
        # Unknown dimension: row lengths are read from a companion feature.
        partitions.append(tf.io.RaggedFeature.RowLengths(f'{fname}.d{i}'))
      else:
        # Static dimension: every row has the same, known length.
        partitions.append(tf.io.RaggedFeature.UniformRowLength(dim))
      # pytype: enable=attribute-error
    return tuple(partitions)

  def get_io_feature(fname: str, value_spec: gt.FieldSpec) -> IOFeature:
    """Returns the tf.io feature configuration for a single field spec."""
    io_dtype = _get_io_type(value_spec.dtype)

    if isinstance(value_spec, tf.RaggedTensorSpec):
      return tf.io.RaggedFeature(
          value_key=fname,
          dtype=io_dtype,
          partitions=get_io_ragged_partitions(fname, value_spec.shape),
          row_splits_dtype=value_spec.row_splits_dtype,
          validate=validate)

    if isinstance(value_spec, tf.TensorSpec):
      if None not in value_spec.shape.as_list():
        # If shape is [d0..dk], where di is static (compile-time constant), the
        # value is parsed as a dense tensor.
        # Only fields recognized by `_is_size_field` get an all-zeros default
        # value; all other fixed-length features have no default.
        return tf.io.FixedLenFeature(
            dtype=io_dtype,
            shape=value_spec.shape,
            default_value=tf.zeros(value_spec.shape, io_dtype)
            if _is_size_field(fname) else None)

      if value_spec.shape[1:].is_fully_defined():
        # If shape is [None, d1..dk], where di is static (compile-time
        # constant), the value is parsed as a ragged tensor with ragged rank 0.
        # For single example parsing this result in a dense tensor, for multiple
        # examples parsing - in ragged.
        partitions = get_io_ragged_partitions(fname, value_spec.shape)
        # Internal sanity check: all non-leading dimensions are static here, so
        # no partition may depend on a row lengths feature.
        # pytype: disable=attribute-error
        assert all(
            isinstance(p, tf.io.RaggedFeature.UniformRowLength)
            for p in partitions)
        # pytype: enable=attribute-error
        # A dense spec carries no row splits dtype of its own, so the graph
        # spec's indices dtype is used instead.
        return tf.io.RaggedFeature(
            value_key=fname,
            dtype=io_dtype,
            partitions=partitions,
            row_splits_dtype=spec.indices_dtype,
            validate=validate)

      raise ValueError(
          ('Expected dense tensor with static non-leading dimensions'
           f', got shape={value_spec.shape}, fname={fname}'))

    raise ValueError(
        f'Unsupported type spec {type(value_spec).__name__}, fname={fname}')

  out = {}
  for fname, value_spec in _flatten_graph_field_specs(spec, '').items():
    if prefix:
      # The prefixed name is used both as the key of the result and as the
      # feature name inside the returned configuration objects.
      fname = f'{prefix}{fname}'
    # NOTE(review): `gp._box_spec` is presumably adjusting the field spec for
    # the rank of `spec` before it is converted - confirm against its
    # definition.
    # pylint: disable=protected-access
    out[fname] = get_io_feature(
        fname, gp._box_spec(spec.rank, value_spec, spec.indices_dtype))
  return out