in training/flax/distil_whisper/partitioner.py [0:0]
def get_data_layout(self, batch_size: Optional[int] = None, host_index: Optional[int] = None) -> DataLayout:
    """Describe how the input batch is split across hosts for this partitioning.

    Without a data axis (``self._data_axis is None``) the whole batch forms a
    single shard: shard id 0 out of 1 shard is reported, the requested
    ``batch_size`` is passed through unchanged (it may be ``None``), and only
    ``jax.process_index() == 0`` is flagged as the first host of the replica
    set. With a data axis, the batch is split along that axis of the global
    mesh as described by the local chunker.

    Args:
      batch_size: requested global batch size. When a data axis exists and this
        is falsy (``None`` or 0), it defaults to the size of the global mesh
        along the data axis. It must be divisible by that mesh size and by the
        number of local chunks along the data axis.
      host_index: reserved for choosing a host explicitly. Only ``None`` is
        supported at the moment; any other value raises.

    Returns:
      A `DataLayout` carrying the batch size, this host's shard id along the
      data axis, the number of shards, and whether this host is the first one
      in its replica set.

    Raises:
      NotImplementedError: if ``host_index`` is not ``None``.
      ValueError: if the batch size is not divisible by the data-axis mesh size
        or by the number of data-axis chunks.
    """
    if host_index is not None:
        raise NotImplementedError("Explicit host_index is not yet implemented.")

    axis = self._data_axis
    if axis is None:
        # No data parallelism: one shard covering everything.
        return DataLayout(
            batch_size=batch_size,
            shard_id=0,
            num_shards=1,
            is_first_host_in_replica_set=jax.process_index() == 0,
        )

    chunker = self._local_chunker
    mesh_size = chunker.global_mesh.shape[axis]
    # A falsy request (None or 0) falls back to the data-axis mesh size.
    global_batch = batch_size if batch_size else mesh_size
    if global_batch % mesh_size:
        raise ValueError(
            f"Batch size ({global_batch}) must be divisible by corresponding mesh size ({mesh_size})."
        )

    num_shards = chunker.num_chunks[axis]
    if global_batch % num_shards:
        raise ValueError(
            f"Batch size ({global_batch}) must be divisible by number of replicas ({num_shards})."
        )

    # Only the replica id is needed from the local chunk description.
    replica_id = chunker.get_local_chunk_info((global_batch,), [axis]).replica_id
    return DataLayout(
        batch_size=int(global_batch),
        shard_id=int(chunker.chunk_ids[axis]),
        num_shards=int(num_shards),
        is_first_host_in_replica_set=(replica_id == 0),
    )