def __init__()

in training/flax/distil_whisper/partitioner.py [0:0]


    def __init__(self, global_mesh: Mesh):
        """Locate this host's slice of ``global_mesh`` along every mesh axis.

        The host's position is anchored on the first device (flattened order)
        of ``global_mesh.local_mesh``; its coordinates in ``global_mesh.devices``
        are looked up, and for each axis two values are recorded:

        * ``num_chunks[axis]``: global axis size // local axis size.
        * ``chunk_ids[axis]``: anchor coordinate on that axis // local axis size.

        Args:
          global_mesh: the full device mesh. Its ``local_mesh`` attribute is
            treated as this process's sub-mesh. This presumably holds only the
            devices addressable by this host; confirm against the caller.

        Raises:
          IndexError: if the anchor local device does not appear in
            ``global_mesh.devices``.
        """
        self.global_mesh = global_mesh
        local_mesh = global_mesh.local_mesh
        axis_names = list(global_mesh.shape.keys())

        # One coordinate per axis: where this host's first device sits in the
        # global device array (first match in C order).
        anchor_device = local_mesh.devices.reshape(-1)[0]
        anchor_coords = np.argwhere(global_mesh.devices == anchor_device)[0]
        host_location = collections.OrderedDict(zip(axis_names, anchor_coords))

        self.mesh_axes = axis_names
        self.num_chunks = collections.OrderedDict()
        self.chunk_ids = collections.OrderedDict()
        for axis_name in self.mesh_axes:
            # A "chunk" along an axis is as wide as this host's local mesh.
            chunk_size = local_mesh.shape[axis_name]
            self.num_chunks[axis_name] = global_mesh.shape[axis_name] // chunk_size
            self.chunk_ids[axis_name] = host_location[axis_name] // chunk_size