in training/flax/distil_whisper/modeling_flax_whisper.py [0:0]
def __init__(self, force_token_map):
    """Build the position-indexed array of token ids that must be forced while decoding.

    Args:
        force_token_map: iterable of ``[generation_index, token_id]`` pairs — at
            decoding position ``generation_index`` the model must emit ``token_id``.
    """
    # The generic `transformers` logit processor builds `force_token_array` as a dictionary - this is not a valid
    # JAX type, and so we switch to using a JAX array instead
    force_token_map = jnp.array(force_token_map, dtype=jnp.int32)
    # Converts the array of format [[index, token]] containing the tokens to be forced to an array, where the
    # index of the array corresponds to the index of the token to be forced. For XLA compatibility,
    # indexes without forced tokens will have a negative value (-1). Array shapes need to be concrete for XLA,
    # so the size is fixed here at init time from the largest forced position rather than built dynamically later.
    # BUG FIX: the size was previously hard-coded to 3, but the last token Whisper forces sits at position 3
    # (e.g. the no-timestamps token in `forced_decoder_ids`), which needs an array of size 4. JAX silently
    # drops out-of-bounds `.at[...].set` updates (default scatter mode is "drop"), so that token was never
    # actually forced. Sizing from max(index) + 1 matches the upstream `transformers` implementation.
    if force_token_map.size:
        num_positions = int(force_token_map[:, 0].max()) + 1
    else:
        num_positions = 0
    force_token_array = jnp.ones(num_positions, dtype=jnp.int32) * -1
    for index, token in force_token_map:
        force_token_array = force_token_array.at[index].set(token)
    self.force_token_array = jnp.int32(force_token_array)