def _get_batch_size_from_input_shapes()

in tensorflow_recommenders/layers/embedding/tpu_embedding_layer.py [0:0]


def _get_batch_size_from_input_shapes(input_shape):
  """From a list of input shapes, gets the per core batch size.

  We want to extract the first dimension for each TensorShape in input_shape
  and ensure that:
  1. They are all the same or None.
  2. They are not all None.

  If SparseTensors are fed directly into call (which in turn calls build),
  during tracing, the shape of the SparseTensor will itself be a tensor which
  results in unknown dimensions for the SparseTensor.

  Args:
    input_shape: A nested structure of `TensorShape`s.

  Returns:
    The per core batch size.

  Raises:
    ValueError: If `input_shape` is empty, if any shape has unknown or zero
      rank, if features disagree on batch size, or if no feature has a known
      batch dimension.
  """
  flattened_input_shape = tf.nest.flatten(input_shape)
  if not flattened_input_shape:
    raise ValueError("No input passed to TPUEmbedding layer.")

  per_core_batch_size = None

  for tensor_shape in flattened_input_shape:
    # rank is None for fully-unknown shapes; comparing None < 1 would raise a
    # confusing TypeError, so reject unknown rank with the same diagnostic.
    if tensor_shape.rank is None or tensor_shape.rank < 1:
      raise ValueError(
          "Received input tensor of shape {}. Rank must be > 0.".format(
              tensor_shape))
    shape = tensor_shape.as_list()
    if shape[0] is not None:
      if per_core_batch_size is None:
        # First feature with a known batch dimension sets the expectation.
        per_core_batch_size = shape[0]
      elif per_core_batch_size != shape[0]:
        raise ValueError("Found multiple batch sizes {} and {} in input. All "
                         "features must have the same batch size.".format(
                             per_core_batch_size, shape[0]))

  if per_core_batch_size is None:
    # Note: space after "batch size." so adjacent string literals don't fuse
    # into "batch size.If you".
    raise ValueError("Unable to determine batch dimension of any features. "
                     "This may happen if all inputs are SparseTensors. If you "
                     "are using tf.keras.Input, you must specify a batch size. "
                     "If you are using the layer in a subclass model and "
                     "calling the layer directly/using build(), you must "
                     "specify a batch size when constructing the layer.")

  return per_core_batch_size