in training/flax/distil_whisper/pipeline.py [0:0]
def chunk_iter_with_batch(self, inputs, chunk_len, stride_left, stride_right, batch_size):
    """Yield batches of overlapping chunks of ``inputs`` with their strides.

    ``inputs`` is cut along its first axis into windows of ``chunk_len``
    elements whose start offsets are ``chunk_len - stride_left - stride_right``
    apart, so neighbouring windows overlap by the two strides. The windows
    are grouped into batches and each batch is passed through
    ``self.feature_extractor``.

    Args:
        inputs: Array exposing ``.shape`` and slicing along its first axis
            (presumably a 1-D NumPy array of audio samples -- confirm
            against the caller).
        chunk_len: Length of each window, in elements of ``inputs``.
        stride_left: Overlap with the previous window; reported as 0 for
            the window that starts at offset 0.
        stride_right: Overlap with the next window; reported as 0 for any
            window that reaches the end of ``inputs``.
        batch_size: Upper bound on the number of windows per batch.
            Windows are spread as evenly as possible over
            ``ceil(num_windows / batch_size)`` batches, so a batch may hold
            fewer than ``batch_size`` windows.

    Yields:
        dict: The feature extractor's output for one batch, plus a
        ``"stride"`` entry holding one
        ``(chunk_length, stride_left, stride_right)`` tuple per window.
        Nothing is yielded when ``inputs`` is empty.

    Raises:
        ValueError: If ``chunk_len`` does not exceed
            ``stride_left + stride_right`` (the window could never advance).
    """
    inputs_len = inputs.shape[0]
    step = chunk_len - stride_left - stride_right
    if step <= 0:
        # A non-positive step would make np.arange fail (step == 0) or
        # silently produce no windows (step < 0); fail loudly instead.
        raise ValueError(
            f"chunk_len ({chunk_len}) must be greater than stride_left + stride_right "
            f"({stride_left + stride_right})"
        )

    all_chunk_start_idx = np.arange(0, inputs_len, step)
    num_samples = len(all_chunk_start_idx)
    if num_samples == 0:
        # Empty input: np.array_split rejects zero sections, so stop here.
        return

    num_batches = math.ceil(num_samples / batch_size)
    # array_split balances the windows across batches rather than filling
    # each batch to batch_size and leaving a small remainder.
    batch_idx = np.array_split(np.arange(num_samples), num_batches)

    for idx in batch_idx:
        chunk_start_idx = all_chunk_start_idx[idx]
        chunk_end_idx = chunk_start_idx + chunk_len

        # Slicing clips at the end of ``inputs``, so trailing windows may be
        # shorter than chunk_len.
        chunks = [
            inputs[chunk_start:chunk_end]
            for chunk_start, chunk_end in zip(chunk_start_idx, chunk_end_idx)
        ]
        processed = self.feature_extractor(
            chunks, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
        )

        # The first window has nothing to its left, and windows reaching the
        # end of the input have nothing to their right.
        _stride_left = np.where(chunk_start_idx == 0, 0, stride_left)
        is_last = np.where(stride_right > 0, chunk_end_idx > inputs_len, chunk_end_idx >= inputs_len)
        _stride_right = np.where(is_last, 0, stride_right)

        chunk_lens = [chunk.shape[0] for chunk in chunks]
        strides = list(zip(chunk_lens, _stride_left, _stride_right))

        yield {"stride": strides, **processed}