def _from_data()

in tensorflow_gnn/graph/graph_piece.py [0:0]


  def _from_data(cls,
                 data: Data,
                 shape: tf.TensorShape,
                 indices_dtype: tf.dtypes.DType,
                 metadata: Metadata = None) -> 'GraphPieceBase':
    """Builds a GraphPiece from a nest of data plus its attributes.

    Args:
      data: a nest of Field and GraphPiece objects. All Fields (including those
        nested inside GraphPiece objects) must have exactly equal batch
        dimensions. (In particular, a dimension that is None for one Field must
        be None for every Field.)
      shape: A hint for the shape of the result. Its rank must be known. It
        must be compatible with (but not necessarily equal to) the common batch
        dimensions of the Fields nested in data. (This aligns this function
        with the requirements of TypeSpec._from_components() and
        BatchableTypeSpec._from_compatible_tensor_list().)
      indices_dtype: indices type to use for potentially ragged fields batching.
      metadata: optional mapping from a string key to hashable values.

    Returns:
      An instance of GraphPieceBase holding the data, after the GraphPieces in
      the data have been transformed to match `indices_dtype` and `metadata`.
      The shape of the result and of its constituent GraphPieces is the common
      batch shape of the data Fields if there are any, or else the `shape`
      argument. In either case, the result's shape is compatible with (but not
      necessarily equal to) the passed-in shape.

    Raises:
      ValueError: if the data Fields do not have equal batch shapes, or if
        their common batch dimensions are not compatible with `shape`.
      TypeError: if a leaf of `data` is neither a GraphPiece, a tf.Tensor nor a
        tf.RaggedTensor.
    """
    # TODO(aferludin,edloper): Clarify the requirements of
    # TypeSpec._from_components(). Why can I safely construct from components
    # with a different dynamic shape, but only if that is statically unknown?

    # pylint: disable=protected-access
    fields_batch_shape = _get_batch_shape_from_fields(data, shape.rank)
    if fields_batch_shape is None:
      # No shape is available from the fields: rely on the caller's hint.
      result_shape = shape
    elif shape.is_compatible_with(fields_batch_shape):
      # The shape derived from the fields takes precedence over the hint.
      result_shape = fields_batch_shape
    else:
      raise ValueError('Fields have batch dimensions that are not compatible'
                       ' with the GraphPiece shape,'
                       f' fields batch dimensions {fields_batch_shape},'
                       f' GraphPiece shape {shape}')

    def refresh_leaf(
        value: Union['GraphPieceBase', Field]
    ) -> Union['GraphPieceBase', Field]:
      """Re-attributes nested GraphPieces; passes tensors through unchanged."""
      if isinstance(value, GraphPieceBase):
        return value._with_attributes(result_shape, indices_dtype, metadata)
      if isinstance(value, (tf.RaggedTensor, tf.Tensor)):
        return value
      raise TypeError(
          f'Invalid type for: {value}; should be tensor or ragged tensor')

    updated_data = tf.nest.map_structure(refresh_leaf, data)
    updated_data_spec = tf.nest.map_structure(type_spec.type_spec_from_value,
                                              updated_data)

    spec_cls = cls._type_spec_cls()
    assert issubclass(spec_cls, GraphPieceSpecBase), spec_cls
    piece_spec = spec_cls(updated_data_spec, result_shape, indices_dtype,
                          metadata)
    return cls(updated_data, piece_spec)