def ComputationShape()

in lingvo/core/py_utils.py [0:0]
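The function depends on module-level imports in py_utils. A minimal sketch, assuming a plain TensorFlow 2.x environment (lingvo itself imports TensorFlow through its `compat` shim, so the exact import lines in the real file may differ):

import functools

import tensorflow as tf
from tensorflow.python.tpu import topology as tf_topology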


def ComputationShape(split_size, topology=None):
  """Decides the computation shape based on the split_size.

  Args:
    split_size: number of accelerators to use per split.
    topology: a serialized string of `tensorflow.tpu.TopologyProto`, or a
      `tf.tpu.experimental.Topology` object, that describes the TPU cluster
      topology. If not set, it'll use a default setting based on split_size.

  Returns:
    A 4-element list that describes the computation shape.
  """
  if topology:
    if isinstance(topology, tf.tpu.experimental.Topology):
      topology_info = topology
    else:
      topology_info = tf_topology.Topology(serialized=topology)
  computation_shape = None
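  # If the topology's full mesh contains exactly split_size cores, use the
  # whole mesh as the computation shape.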
  if topology and functools.reduce(lambda a, b: a * b,
                                   topology_info.mesh_shape) == split_size:
    computation_shape = topology_info.mesh_shape
  elif split_size == 1:
    computation_shape = [1, 1, 1, 1]
  elif topology and topology_info.mesh_shape[
      -1] == 1 and split_size in topology_info.mesh_shape:
    # For Megacore (one core per chip), if split_size exactly matches one of
    # the mesh dimensions, place the whole split along that dimension.
    computation_shape = [1, 1, 1, 1]
    computation_shape[topology_info.mesh_shape.tolist().index(
        split_size)] = split_size
  else:
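    # Otherwise, fall back to a hard-coded mapping from the number of chips
    # in the split (split_size // cores_per_chip) to a mesh shape.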
    if topology:
      cores_per_chip = topology_info.mesh_shape[-1]
    else:
      cores_per_chip = 2
    assert split_size % cores_per_chip == 0
    split_chips = split_size // cores_per_chip
    if split_chips == 1:
      computation_shape = [1, 1, 1, cores_per_chip]
    elif split_chips == 2:
      computation_shape = [1, 2, 1, cores_per_chip]
    elif split_chips == 4:
      computation_shape = [2, 2, 1, cores_per_chip]
    elif split_chips == 8:
      computation_shape = [4, 2, 1, cores_per_chip]
    elif split_chips == 12:
      computation_shape = [1, 1, 12, cores_per_chip]
    elif split_chips == 16:
      computation_shape = [4, 4, 1, cores_per_chip]
    elif split_chips == 24:
      computation_shape = [1, 2, 12, cores_per_chip]
    elif split_chips == 32:
      if topology and topology_info.mesh_shape[1] == 32:
        # Forward within-replica all-reduces are performed along the column;
        # backward gradient cross-replica all-reduces are performed along the
        # row. This currently performs better than the strided pattern.
        computation_shape = [1, 32, 1, cores_per_chip]
      else:
        computation_shape = [4, 8, 1, cores_per_chip]
    elif split_chips == 64:
      computation_shape = [8, 8, 1, cores_per_chip]
    elif split_chips == 128:
      computation_shape = [8, 16, 1, cores_per_chip]
    elif split_chips == 256:
      computation_shape = [16, 16, 1, cores_per_chip]
    elif split_chips == 512:
      computation_shape = [16, 32, 1, cores_per_chip]
    elif split_chips == 1024:
      computation_shape = [32, 32, 1, cores_per_chip]
    elif split_chips == 2048:
      computation_shape = [64, 32, 1, cores_per_chip]
    elif split_chips == 4096:
      computation_shape = [128, 32, 1, cores_per_chip]
    else:
      assert False, ('Model parallelism with %d devices is currently not'
                     ' supported.' % split_size)
  assert computation_shape is not None
  return computation_shape
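
A minimal usage sketch; the expected values below follow directly from the default branch above (no topology, so two cores per chip are assumed), while the TPU initialization lines are a hypothetical setup rather than part of py_utils:

# Default mapping with no topology (cores_per_chip defaults to 2).
assert ComputationShape(1) == [1, 1, 1, 1]
assert ComputationShape(8) == [2, 2, 1, 2]    # 4 chips -> 2x2 mesh
assert ComputationShape(16) == [4, 2, 1, 2]   # 8 chips -> 4x2 mesh

# With a real TPU cluster, pass the Topology object (or its serialized proto):
#   resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='...')
#   topology = tf.tpu.experimental.initialize_tpu_system(resolver)
#   shape = ComputationShape(8, topology)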