in tensorflow_gnn/graph/graph_piece.py [0:0]
def _from_data(cls,
               data: Data,
               shape: tf.TensorShape,
               indices_dtype: tf.dtypes.DType,
               metadata: Metadata = None) -> 'GraphPieceBase':
  """Creates a GraphPiece from its data and attributes.

  Args:
    data: a nest of Field and GraphPiece objects. The batch dimensions of all
      Fields (incl. those nested in GraphPiece objects) must be exactly equal.
      (In particular, if a dimension is None for one, it must be None for
      all.)
    shape: A hint for the shape of the result. This shape must have a known
      rank. It must be compatible with (but not necessarily equal to) the
      common batch dimensions of the Fields nested in data. (This is meant to
      align this function with the requirements of TypeSpec._from_components()
      and BatchableTypeSpec._from_compatible_tensor_list().)
    indices_dtype: indices type to use for potentially ragged fields batching.
    metadata: optional mapping from a string key to hashable values.

  Returns:
    An instance of `cls` holding the data. Before construction, every
    GraphPiece nested in the data has been rebuilt to match the resolved
    shape, `indices_dtype` and `metadata`. The shape of the result and its
    constituent GraphPieces is the common batch shape of all data Fields if
    there are any. Otherwise it is the `shape` passed in as an argument. In
    either case, the shape of the result is compatible with the passed-in
    shape (but not necessarily equal to it).

  Raises:
    ValueError: if the data Fields do not have equal batch shapes, or if their
      common batch shape is not compatible with `shape`.
    TypeError: if a leaf of `data` is neither a GraphPiece nor a `tf.Tensor`
      or `tf.RaggedTensor`.
  """
  # TODO(aferludin,edloper): Clarify the requirements of
  # TypeSpec._from_components(). Why can I safely construct from components
  # with a different dynamic shape, but only if that is statically unknown?

  # pylint: disable=protected-access
  def update_fn(value: Union['GraphPieceBase', Field], shape: tf.TensorShape,
                indices_dtype: tf.dtypes.DType,
                metadata: Metadata) -> Union['GraphPieceBase', Field]:
    """Propagates the resolved attributes to one leaf of `data`.

    Nested GraphPieces are rebuilt with the given attributes. Tensor and
    ragged-tensor Fields are returned unchanged. Any other value is rejected.
    """
    if isinstance(value, GraphPieceBase):
      return value._with_attributes(shape, indices_dtype, metadata)
    if not isinstance(value, (tf.RaggedTensor, tf.Tensor)):
      raise TypeError(
          f'Invalid type for: {value}; should be tensor or ragged tensor')
    return value

  # If a batch shape can be inferred from the Fields, it replaces the `shape`
  # hint. The hint must be compatible with it. Otherwise the hint is used
  # as-is.
  shape_from_data = _get_batch_shape_from_fields(data, shape.rank)
  if shape_from_data is not None:
    if not shape.is_compatible_with(shape_from_data):
      raise ValueError('Fields have batch dimensions that are not compatible'
                       ' with the GraphPiece shape,'
                       f' fields batch dimensions {shape_from_data},'
                       f' GraphPiece shape {shape}')
    shape = shape_from_data

  # Make every nested GraphPiece agree with the resolved shape and with the
  # requested indices dtype and metadata.
  data = tf.nest.map_structure(
      functools.partial(
          update_fn,
          shape=shape,
          indices_dtype=indices_dtype,
          metadata=metadata), data)
  # The spec is derived from the *updated* data, so nested specs carry the
  # new attributes.
  data_spec = tf.nest.map_structure(type_spec.type_spec_from_value, data)
  cls_spec = cls._type_spec_cls()
  assert issubclass(cls_spec, GraphPieceSpecBase), cls_spec
  return cls(data, cls_spec(data_spec, shape, indices_dtype, metadata))