in training/flax/distil_whisper/modeling_flax_whisper.py [0:0]
def setup(self) -> None:
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
"embed_dim must be divisible by num_heads (got `embed_dim`:"
f" {self.embed_dim} and `num_heads`: {self.num_heads})."
)
dense = partial(
DenseGeneral,
self.embed_dim,
axis=-1,
dtype=self.dtype,
params_dtype=self.params_dtype,
kernel_axes=("embed", "joined_kv"),
)
self.q_proj = dense(use_bias=self.bias)
self.k_proj = dense(use_bias=False)
self.v_proj = dense(use_bias=self.bias)
self.out_proj = DenseGeneral(
self.embed_dim,
axis=-1,
dtype=self.dtype,
params_dtype=self.params_dtype,
kernel_axes=("joined_kv", "embed"),
use_bias=self.bias,
)
if self.causal:
self.causal_mask = make_causal_mask(
jnp.ones((1, self.config.max_target_positions), dtype="bool"),
dtype="bool",
)