def get_gpu_mesh()

in training/flax/distil_whisper/partitioner.py


def get_gpu_mesh(num_partitions: int) -> Mesh:
    """Mesh for GPUs that preferentially places 'model' on NVLink."""
    nvlink_size = jax.local_device_count()
    dcn_size = jax.process_count()
    # Place as much of the model-parallel axis as possible on the fast
    # intra-node (NVLink) interconnect.
    nvlink_mp = min(num_partitions, nvlink_size)
    # Remaining local devices are used for data parallelism within the node.
    nvlink_dp, extra1 = divmod(nvlink_size, nvlink_mp)
    # Any model parallelism that does not fit inside a node spills over to the
    # slower cross-node (DCN) interconnect.
    dcn_mp, extra2 = divmod(num_partitions, nvlink_mp)
    assert not (
        extra1 or extra2
    ), "number of partitions on GPU must be a factor or multiple of the number of local devices"
    # The remaining processes form the cross-node data-parallel axis.
    dcn_dp = dcn_size // dcn_mp

    # Build a two-level (intra-node x cross-node) device grid; each process
    # (host) is treated as one granule of the DCN mesh.
    devices = create_hybrid_device_mesh(
        mesh_shape=[nvlink_dp, nvlink_mp],
        dcn_mesh_shape=[dcn_dp, dcn_mp],
        process_is_granule=True,
    )

    global_mesh = Mesh(devices, ["data", "model"])
    logging.info("global_mesh axis_names: %s", global_mesh.axis_names)
    logging.info("global_mesh devices: %s", global_mesh.devices)
    return global_mesh
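
A minimal usage sketch (not part of the source file): it assumes get_gpu_mesh is importable from distil_whisper.partitioner and that the module's dependencies (jax, absl logging, jax.sharding.Mesh, jax.experimental.mesh_utils.create_hybrid_device_mesh) are in place. On a single 8-GPU host, num_partitions=2 yields a (4, 2) grid: 4-way data parallelism and 2-way model parallelism, both over NVLink.

import jax
from jax.sharding import NamedSharding, PartitionSpec

from distil_whisper.partitioner import get_gpu_mesh  # assumed import path

# 8 local GPUs, 1 process, num_partitions=2:
#   nvlink_mp = 2, nvlink_dp = 4, dcn_mp = 1, dcn_dp = 1  ->  device grid (4, 2)
mesh = get_gpu_mesh(num_partitions=2)
print(dict(mesh.shape))  # {'data': 4, 'model': 2}

# Shard a weight matrix column-wise across the 'model' axis; the 'data' axis
# is left for batch sharding by the training loop.
weights = jax.numpy.zeros((1024, 1024))
sharded = jax.device_put(weights, NamedSharding(mesh, PartitionSpec(None, "model")))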