in training/flax/distil_whisper/partitioner.py [0:0]
def get_gpu_mesh(num_partitions: int) -> Mesh:
    """Mesh for GPUs that preferentially places 'model' on NVLink.

    Model-parallel partitions are first laid out over the devices attached to
    a single process (assumed to share a fast interconnect such as NVLink).
    Only when ``num_partitions`` exceeds the local device count are partitions
    spread across processes over the data-center network (DCN). The remaining
    factor on each level becomes data parallelism.

    Args:
      num_partitions: total number of model-parallel partitions. Must be a
        positive integer that is either a factor or a multiple of
        ``jax.local_device_count()``. When it spans several processes, the
        process count must be divisible by the number of processes one model
        replica spans.

    Returns:
      A ``Mesh`` built from ``create_hybrid_device_mesh`` with axis names
      ``["data", "model"]``.

    Raises:
      ValueError: if ``num_partitions`` is not positive, or if the requested
        partitioning cannot be tiled evenly over the local devices or the
        processes.
    """
    # Guard explicitly: divmod below would divide by zero for 0, and negative
    # values would produce negative mesh dimensions.
    if num_partitions < 1:
        raise ValueError(
            f"num_partitions must be a positive integer, got {num_partitions}"
        )

    nvlink_size = jax.local_device_count()  # devices attached to this process
    dcn_size = jax.process_count()  # number of processes

    # Fill the process-local devices with model partitions first.
    nvlink_mp = min(num_partitions, nvlink_size)
    nvlink_dp, extra1 = divmod(nvlink_size, nvlink_mp)
    # Model partitions that do not fit locally are spread across processes.
    dcn_mp, extra2 = divmod(num_partitions, nvlink_mp)
    # An explicit raise (not `assert`) so the check survives `python -O`.
    if extra1 or extra2:
        raise ValueError(
            "number of partitions on GPU must be a factor or multiple of the "
            f"number of local devices (num_partitions={num_partitions}, "
            f"local devices={nvlink_size})"
        )

    # Previously a floor division: a remainder (or dcn_mp > dcn_size, which
    # gives dcn_dp == 0) silently produced a DCN shape that does not cover
    # the available processes.
    dcn_dp, extra3 = divmod(dcn_size, dcn_mp)
    if extra3:
        raise ValueError(
            f"num_partitions={num_partitions} needs {dcn_mp} processes per "
            f"model replica, which does not evenly divide the {dcn_size} "
            f"available processes ({nvlink_size} local devices each)"
        )

    devices = create_hybrid_device_mesh(
        mesh_shape=[nvlink_dp, nvlink_mp],
        dcn_mesh_shape=[dcn_dp, dcn_mp],
        process_is_granule=True,
    )
    global_mesh = Mesh(devices, ["data", "model"])
    logging.info("global_mesh axis_names: %s", global_mesh.axis_names)
    logging.info("global_mesh devices: %s", global_mesh.devices)
    return global_mesh