def __init__()

in training/flax/distil_whisper/modeling_flax_whisper.py


    def __init__(self, force_token_map):
        # The generic `transformers` logit processor builds `force_token_map` as a dictionary - this is not a valid
        # JAX type, so we switch to using a JAX array instead
        force_token_map = jnp.array(force_token_map)
        # Converts the array of format [[index, token]] containing the tokens to be forced into an array whose
        # index corresponds to the generation index at which the token is forced. For XLA compatibility,
        # indexes without forced tokens will have a negative value (-1). The last token we ever need to force in
        # Whisper is at position 3, so a fixed array of length 4 (covering indices 0-3) is sufficient. The native
        # version instead builds this tensor dynamically from the contents of `force_token_map`, which is not
        # permitted here since array shapes must be concrete (static) for XLA.
        force_token_array = jnp.ones(4, dtype=jnp.int32) * -1
        for index, token in force_token_map:
            force_token_array = force_token_array.at[index].set(token)
        self.force_token_array = jnp.int32(force_token_array)
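
As a worked illustration of the conversion described in the comment above, the following self-contained sketch builds the array from a hypothetical `force_token_map` (the token ids are illustrative placeholders chosen for the example, not taken from this file):

    import jax.numpy as jnp

    # Hypothetical forced-token map: language token at generation index 1, task token at
    # index 2, <|notimestamps|> at index 3 (token ids are illustrative placeholders).
    force_token_map = jnp.array([[1, 50259], [2, 50359], [3, 50363]])

    # Fixed-shape array: array index = generation position, value = forced token id, -1 = no forcing.
    force_token_array = jnp.ones(4, dtype=jnp.int32) * -1
    for index, token in force_token_map:
        force_token_array = force_token_array.at[index].set(token)

    print(force_token_array)  # [   -1 50259 50359 50363] -> index 0 is not forced, hence -1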
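
Only `__init__` is shown here. As context for how `force_token_array` is consumed at decode time, below is a functional sketch of the forcing logic, a simplified stand-in for the processor's `__call__` (whose exact implementation is not shown in this section): at generation index `cur_len`, every logit is set to `-inf` except the forced token, which is given log-probability 0; positions without a forced token (value -1) leave the scores untouched. The function name `force_tokens`, the token ids, and the vocabulary size are all assumptions made for the example.

    import jax.numpy as jnp

    def force_tokens(scores, cur_len, force_token_array):
        """Return `scores` with the token at generation index `cur_len` forced, if any (sketch)."""
        # Clamp the lookup so positions beyond the forced range stay in bounds, then treat
        # them as "no forced token" (-1).
        idx = jnp.clip(cur_len, 0, force_token_array.shape[0] - 1)
        token_id = jnp.where(cur_len < force_token_array.shape[0], force_token_array[idx], -1)

        # Candidate scores with everything masked to -inf except the forced token (log-prob 0).
        forced_scores = jnp.full_like(scores, -float("inf"))
        forced_scores = forced_scores.at[:, token_id].set(0.0)

        # Apply the forcing only when a valid (non-negative) token id is present.
        return jnp.where(token_id >= 0, forced_scores, scores)

    scores = jnp.zeros((1, 51865), dtype=jnp.float32)  # (batch, vocab); vocab size is illustrative
    forced = force_tokens(scores, cur_len=2, force_token_array=jnp.array([-1, 50259, 50359, 50363]))
    print(jnp.argmax(forced, axis=-1))  # [50359] -> the token forced at position 2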