def _create_empty_value()

in tensorflow_gnn/graph/graph_piece.py [0:0]


  def _create_empty_value(self) -> GraphPieceBase:
    """Creates minimal empty GraphPiece allowed by this spec.

    Rules:
      1. all unknown dimensions are assumed to be 0.
      2. field values for fixed size dimensions are set to empty with tf.zeros.
      3. resulting tensor should have no values (empty values of flat values).

    NOTE: this is temporary workaround to allow to construct GraphTensors with
    empty batch dimensions to use with TF distribution strategy (b/183969859).
    The method could be removed in the future without notice, PLEASE DO NOT USE.

    Returns:
      GraphPiece compatible with this spec.

    Raises:
      ValueError: if a dense field shape has no unknown or zero dimension (so
        no empty tensor of that shape exists), or if a field spec has an
        unsupported type.
    """

    def create_empty_dense_field(shape: tf.TensorShape,
                                 dtype: tf.dtypes.DType) -> tf.Tensor:
      """Returns tf.zeros of `shape` with unknown dims set to 0 (must be empty)."""
      dims = [(0 if d is None else d) for d in shape.as_list()]
      if 0 not in dims:
        raise ValueError(
            f'Could not create empty tensor for non-empty shape {shape}')
      return tf.zeros(tf.TensorShape(dims), dtype)

    def create_empty_ragged_field(spec: tf.RaggedTensorSpec) -> Field:
      """Returns an empty (possibly ragged) tensor for the ragged `spec`."""
      if spec.value_type == tf.Tensor:
        # For ragged rank-0 tensors values are dense tensors.
        return create_empty_dense_field(spec.shape, spec.dtype)

      assert spec.value_type == tf.RaggedTensor
      assert spec.ragged_rank > 0

      # Flat values have shape [num_values] + shape[ragged_rank + 1:]. Their
      # outer-most dimension is the total number of values. It is NOT
      # spec.shape[ragged_rank], which is the row length of the inner-most
      # partition and may be a fixed (uniform) size. For an empty value the
      # number of values is 0.
      inner_shape = spec.shape[(spec.ragged_rank + 1):]
      assert inner_shape.is_fully_defined(), inner_shape
      flat_values_shape = tf.TensorShape([0]).concatenate(inner_shape)

      # Use empty row splits for ragged dimensions and keep uniform dimensions
      # unchanged. Partitioning an empty tensor either way yields 0 rows.
      empty_row_splits = tf.constant([0], dtype=spec.row_splits_dtype)
      result = create_empty_dense_field(flat_values_shape, spec.dtype)
      for dim in reversed(spec.shape[1:(spec.ragged_rank + 1)].as_list()):
        if dim is None:
          result = tf.RaggedTensor.from_row_splits(
              result,
              empty_row_splits,
              validate=const.validate_internal_results)
        else:
          result = tf.RaggedTensor.from_uniform_row_length(
              result,
              tf.convert_to_tensor(dim, dtype=spec.row_splits_dtype),
              validate=const.validate_internal_results)
      return result

    def create_empty_field(spec):
      """Dispatches on the field spec type to build an empty field value."""
      if isinstance(spec, GraphPieceSpecBase):
        return spec._create_empty_value()  # pylint: disable=protected-access

      if isinstance(spec, tf.RaggedTensorSpec):
        return create_empty_ragged_field(cast(tf.RaggedTensorSpec, spec))

      if isinstance(spec, tf.TensorSpec):
        return create_empty_dense_field(spec.shape, spec.dtype)

      raise ValueError(f'Unsupported field type {type(spec).__name__}')

    dummy_fields = tf.nest.map_structure(create_empty_field, self._data_spec)

    cls = self.value_type
    assert issubclass(cls, GraphPieceBase), cls
    result = cls(dummy_fields, self)
    if const.validate_internal_results:
      assert self.is_compatible_with(result)
    return result