def constant_value_as_shape()

in tensorflow/tensorflow/python/framework/tensor_util.py [0:0]


def constant_value_as_shape(tensor):  # pylint: disable=invalid-name
  """A version of `constant_value()` that returns a `TensorShape`.

  This version should be used when a constant tensor value is
  interpreted as a (possibly partial) shape, e.g. in the shape
  function for `tf.reshape()`. By explicitly requesting a
  `TensorShape` as the return value, it is possible to represent
  unknown dimensions; by contrast, `constant_value()` is
  all-or-nothing.

  The result is derived as follows:

  * Eager tensors are read directly with `numpy()`; entries equal to -1
    become unknown dimensions.
  * A scalar must have the value -1 and yields a shape of unknown rank.
  * A vector whose static shape is `[0]` yields the scalar shape `[]`.
  * For vectors produced by `Shape`, `Pack`, `Concat`, `ConcatV2` and
    (simple) `StridedSlice` ops, the shape is reconstructed from the op's
    inputs, so that partially known shapes are preserved.
  * For a `Placeholder` that is an internal capture of a function graph, the
    corresponding external capture is evaluated instead.
  * Otherwise the result has the rank given by the vector's static length,
    refined with `constant_value(tensor)` when that is available (negative
    entries become unknown dimensions).

  Args:
    tensor: The rank-0 or rank-1 Tensor to be evaluated.

  Returns:
    A `TensorShape` based on the constant value of the given `tensor`.

  Raises:
    ValueError: If the shape is rank-0 and is not statically known to be -1.
  """
  if isinstance(tensor, ops.EagerTensor):
    if tensor.get_shape().ndims == 0:
      # A scalar value cannot be iterated over, so handle it explicitly and
      # apply the same contract as for graph tensors below: only the scalar
      # -1 is accepted, and it denotes a completely unknown shape.
      scalar = tensor.numpy()
      if scalar != -1:
        raise ValueError(
            "Received a scalar value '%s' as shape; require a statically known "
            "scalar with value '-1' to describe an unknown shape." % scalar)
      return tensor_shape.unknown_shape()
    return tensor_shape.as_shape(
        [dim if dim != -1 else None for dim in tensor.numpy()])

  if tensor.get_shape().ndims == 0:
    value = constant_value(tensor)
    if value is None:
      raise ValueError(
          "Received a scalar with unknown value as shape; require a statically "
          "known scalar with value '-1' to describe an unknown shape.")
    if value != -1:
      raise ValueError(
          "Received a scalar value '%s' as shape; require a statically known "
          "scalar with value '-1' to describe an unknown shape." % value)
    return tensor_shape.unknown_shape()

  # From here on `tensor` is treated as a vector of dimension sizes.
  shape = tensor.get_shape().with_rank(1)
  if shape == [0]:
    # An empty vector describes a scalar shape.
    return tensor_shape.TensorShape([])
  elif tensor.op.type == "Shape":
    # The value of `tf.shape(x)` is described by the static shape of `x`.
    return tensor.op.inputs[0].get_shape()
  elif tensor.op.type == "Pack":
    ret = tensor_shape.TensorShape([])  # Empty list.
    # Since we expect rank 1 inputs, Pack's axis must be zero, otherwise it
    # would not be rank 1.
    assert tensor.op.get_attr("axis") == 0
    for pack_input in tensor.op.inputs:
      # `pack_input` must be a scalar. Attempt to evaluate it, and append it
      # to `ret`.
      pack_input_val = constant_value(pack_input)
      if pack_input_val is None or pack_input_val < 0:
        # Unknown or negative entries become unknown dimensions.
        new_dim = tensor_shape.Dimension(None)
      else:
        new_dim = tensor_shape.Dimension(pack_input_val)
      ret = ret.concatenate([new_dim])
    return ret
  elif tensor.op.type == "Concat":
    # We assume that `tensor.op.inputs[0]` evaluates to 0, as this is
    # the only legal value when concatenating vectors, and it will
    # have been checked by a previous shape function.
    ret = tensor_shape.TensorShape([])  # Empty list.
    for concat_input in tensor.op.inputs[1:]:
      # `concat_input` must be a vector. Attempt to evaluate it as a shape,
      # and concatenate it with `ret`.
      ret = ret.concatenate(constant_value_as_shape(concat_input))
    return ret
  elif tensor.op.type == "ConcatV2":
    # We assume that `tensor.op.inputs[-1]` evaluates to 0, as this is
    # the only legal value when concatenating vectors, and it will
    # have been checked by a previous shape function.
    ret = tensor_shape.TensorShape([])  # Empty list.
    for concat_input in tensor.op.inputs[:-1]:
      # `concat_input` must be a vector. Attempt to evaluate it as a shape,
      # and concatenate it with `ret`.
      ret = ret.concatenate(constant_value_as_shape(concat_input))
    return ret
  elif tensor.op.type == "StridedSlice":
    # Best effort: if anything below cannot be resolved, fall through to the
    # generic handling at the end of the function.
    try:
      begin = constant_value(tensor.op.inputs[1])
      end = constant_value(tensor.op.inputs[2])
      strides = constant_value(tensor.op.inputs[3])
      if begin is not None and end is not None and strides is not None:
        # Only the first entry of each argument is used.
        begin = begin[0]
        end = end[0]
        strides = strides[0]
        # A mask value of 1 replaces the corresponding bound with an open
        # (`None`) slice bound.
        begin_mask = tensor.op.get_attr("begin_mask")
        if begin_mask == 1:
          begin = None
        end_mask = tensor.op.get_attr("end_mask")
        if end_mask == 1:
          end = None

        ellipsis_mask = tensor.op.get_attr("ellipsis_mask")
        new_axis_mask = tensor.op.get_attr("new_axis_mask")
        shrink_axis_mask = tensor.op.get_attr("shrink_axis_mask")
        valid_attributes = (not ellipsis_mask and not new_axis_mask and
                            not shrink_axis_mask and (not begin_mask or
                                                      (begin_mask == 1)) and
                            (not end_mask or (end_mask == 1)))
        if valid_attributes:  # additional inputs not supported
          prev = constant_value_as_shape(tensor.op.inputs[0])
          return tensor_shape.TensorShape(prev[begin:end:strides])

    except (ValueError, TypeError):
      # ValueError could come from get_attr or slicing prev; TypeError could
      # come from slicing prev. Either way, use the generic handling below.
      pass
  elif (tensor.op.type == "Placeholder" and
        tensor.op.graph.building_function and
        hasattr(tensor.op.graph, "internal_captures")):
    # If we are inside a FuncGraph try to lookup the constant value of the
    # corresponding external capture. Note that we only look at captures and
    # not the fed inputs because those can be fed different values in different
    # instantiations of the function call or different iterations of a
    # tf.while_loop.
    for i, capture in enumerate(tensor.op.graph.internal_captures):
      if capture is tensor:
        external_capture = tensor.op.graph.external_captures[i]
        return constant_value_as_shape(external_capture)

  # Generic handling: the rank is the static length of the vector (possibly
  # unknown); refine the dimensions if the whole value is statically known.
  ret = tensor_shape.unknown_shape(shape.dims[0].value)
  value = constant_value(tensor)
  if value is not None:
    ret = ret.merge_with(
        tensor_shape.TensorShape([d if d >= 0 else None for d in value]))
  return ret