def __call__()

in training/flax/distil_whisper/modeling_flax_whisper.py [0:0]


    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        """Force the token scheduled for generation step ``cur_len``, if any.

        ``self.force_token_array`` is indexed by generation step. A non-negative
        entry is the id of the token to force at that step; a negative entry
        means "no forcing at this step".

        Args:
            input_ids: Unused by this processor; accepted for interface
                compatibility.
            scores: 2-D array of scores. The forced token id indexes its second
                axis.
            cur_len: Current generation step, used as an index into
                ``self.force_token_array``.

        Returns:
            ``scores`` unchanged when ``cur_len`` lies beyond the end of
            ``self.force_token_array`` or the entry for this step is negative.
            Otherwise an array shaped like ``scores`` that is ``-inf``
            everywhere except the forced token's column, which is ``0``.
        """
        forced_tokens = self.force_token_array

        def _keep_scores():
            # Identity branch: hand the incoming scores back untouched.
            return scores

        def _only_allow_forced_token():
            token_id = forced_tokens[cur_len]
            num_rows = scores.shape[0]

            # Start from an all -inf array, then write a zero column at the
            # forced token so it is the only finite entry in each row.
            blocked = jnp.ones_like(scores, dtype=scores.dtype) * -float("inf")
            zero_column = jnp.zeros((num_rows, 1), dtype=scores.dtype)
            return lax.dynamic_update_slice(blocked, zero_column, (0, token_id))

        def _apply_schedule_entry():
            # Negative entries mark steps where nothing is forced.
            return lax.cond(
                forced_tokens[cur_len] >= 0,
                _only_allow_forced_token,
                _keep_scores,
            )

        # Steps at or beyond the end of the schedule are never forced.
        past_schedule = cur_len >= forced_tokens.shape[0]
        return lax.cond(past_schedule, _keep_scores, _apply_schedule_entry)