in lingvo/core/py_utils.py
# Module-level imports required by this excerpt.
import functools

from lingvo import compat as tf
from tensorflow.python.tpu import topology as tf_topology


def ComputationShape(split_size, topology=None):
"""Decides the computation shape based on the split_size.
Args:
split_size: number of accelerators to use per split.
topology: a serialized string of `tensorflow.tpu.TopologyProto`, or a
`tf.tpu.experimental.Topology` object, that describes the TPU cluster
topology. If not set, it'll use a default setting based on split_size.
Returns:
A 4-element list that describes the computation shape.
"""
if topology:
if isinstance(topology, tf.tpu.experimental.Topology):
topology_info = topology
else:
topology_info = tf_topology.Topology(serialized=topology)
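  # topology_info.mesh_shape is a rank-1 numpy array of per-dimension sizes
  # (typically [x, y, z, cores_per_chip]); its product is the total number
  # of cores in the slice.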
computation_shape = None
  if topology and functools.reduce(
      lambda a, b: a * b, topology_info.mesh_shape) == split_size:
computation_shape = topology_info.mesh_shape
elif split_size == 1:
computation_shape = [1, 1, 1, 1]
  elif (topology and topology_info.mesh_shape[-1] == 1 and
        split_size in topology_info.mesh_shape):
    # For Megacore (one core per chip): if split_size exactly matches one of
    # the mesh dimensions, map it onto that dimension.
    computation_shape = [1, 1, 1, 1]
    idx = topology_info.mesh_shape.tolist().index(split_size)
    computation_shape[idx] = split_size
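    # e.g. a Megacore mesh_shape of [4, 4, 1, 1] with split_size=4 yields
    # [4, 1, 1, 1].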
else:
if topology:
cores_per_chip = topology_info.mesh_shape[-1]
else:
cores_per_chip = 2
    assert split_size % cores_per_chip == 0, (
        'split_size=%d must be a multiple of %d cores per chip.' %
        (split_size, cores_per_chip))
    split_chips = split_size // cores_per_chip
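    # The default shapes below choose near-square x*y slices (z=1) for each
    # chip count, with a few pod-shape-specific exceptions (12, 24, and 32
    # chips).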
if split_chips == 1:
computation_shape = [1, 1, 1, cores_per_chip]
elif split_chips == 2:
computation_shape = [1, 2, 1, cores_per_chip]
elif split_chips == 4:
computation_shape = [2, 2, 1, cores_per_chip]
elif split_chips == 8:
computation_shape = [4, 2, 1, cores_per_chip]
elif split_chips == 12:
computation_shape = [1, 1, 12, cores_per_chip]
elif split_chips == 16:
computation_shape = [4, 4, 1, cores_per_chip]
elif split_chips == 24:
computation_shape = [1, 2, 12, cores_per_chip]
elif split_chips == 32:
if topology and topology_info.mesh_shape[1] == 32:
        # Forward within-replica all-reduces are performed along the columns;
        # backward gradient cross-replica all-reduces are performed along the
        # rows. This currently performs better than the strided pattern.
computation_shape = [1, 32, 1, cores_per_chip]
else:
computation_shape = [4, 8, 1, cores_per_chip]
elif split_chips == 64:
computation_shape = [8, 8, 1, cores_per_chip]
elif split_chips == 128:
computation_shape = [8, 16, 1, cores_per_chip]
elif split_chips == 256:
computation_shape = [16, 16, 1, cores_per_chip]
elif split_chips == 512:
computation_shape = [16, 32, 1, cores_per_chip]
elif split_chips == 1024:
computation_shape = [32, 32, 1, cores_per_chip]
elif split_chips == 2048:
computation_shape = [64, 32, 1, cores_per_chip]
elif split_chips == 4096:
computation_shape = [128, 32, 1, cores_per_chip]
else:
assert False, ('Model parallelism with %d devices is currently not'
' supported.' % split_size)
assert computation_shape is not None
return computation_shape
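
# Minimal usage sketch (illustrative only, not part of py_utils.py). Without
# a topology argument, 2 cores per chip are assumed:
#
#   ComputationShape(1)   => [1, 1, 1, 1]
#   ComputationShape(8)   => [2, 2, 1, 2]   # 4 chips in a 2x2 slice
#   ComputationShape(128) => [8, 8, 1, 2]   # 64 chips in an 8x8 slice
#
# With a live TPU, a `tf.tpu.experimental.Topology` can be passed in;
# `tf.tpu.experimental.initialize_tpu_system` returns one:
#
#   resolver = tf.distribute.cluster_resolver.TPUClusterResolver('')
#   tf.config.experimental_connect_to_cluster(resolver)
#   topology = tf.tpu.experimental.initialize_tpu_system(resolver)
#   shape = ComputationShape(8, topology=topology)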