in training/flax/distil_whisper/partitioner.py [0:0]
def __init__(self, global_mesh: Mesh):
    """Locate this host's local sub-mesh as a "chunk" of the global mesh.

    The sub-mesh returned by ``global_mesh.local_mesh`` (presumably the
    devices owned by this host -- confirm against jax docs) is treated as one
    chunk. For every mesh axis this records how many chunks of that size fit
    along the axis and which chunk index the local sub-mesh occupies.

    Args:
      global_mesh: the full device mesh. It must expose ``devices`` (an array
        of devices), ``shape`` (an ordered mapping of axis name -> size) and
        ``local_mesh`` (a sub-mesh with the same attributes).

    Sets:
      global_mesh: the mesh that was passed in.
      mesh_axes: list of axis names in ``global_mesh.shape`` order.
      num_chunks: OrderedDict of axis name -> global size // local size.
      chunk_ids: OrderedDict of axis name -> index of the local chunk.

    Raises:
      IndexError: if the first local device does not appear in
        ``global_mesh.devices``.
    """
    self.global_mesh = global_mesh
    local_mesh = global_mesh.local_mesh
    axis_names = list(global_mesh.shape.keys())

    # Use the first local device (in flattened order) as the anchor. Its
    # coordinates in the global device array give the offset of the local
    # block along every axis. np.nonzero yields one index array per axis;
    # element 0 of each is the first match in row-major order.
    anchor_device = local_mesh.devices.reshape(-1)[0]
    match_indices = np.nonzero(global_mesh.devices == anchor_device)
    anchor_coords = tuple(axis_hits[0] for axis_hits in match_indices)
    host_location = collections.OrderedDict(zip(axis_names, anchor_coords))

    # NOTE(review): assumes the local sub-mesh is an axis-aligned block whose
    # sizes divide the global sizes; floor division is used without checks.
    chunk_counts = collections.OrderedDict()
    chunk_positions = collections.OrderedDict()
    for axis in axis_names:
        chunk_len = local_mesh.shape[axis]
        chunk_counts[axis] = global_mesh.shape[axis] // chunk_len
        chunk_positions[axis] = host_location[axis] // chunk_len

    self.num_chunks = chunk_counts
    self.chunk_ids = chunk_positions
    self.mesh_axes = axis_names