def setup()

in training/flax/distil_whisper/modeling_flax_whisper.py [0:0]


    def setup(self) -> None:
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads (got `embed_dim`:"
                f" {self.embed_dim} and `num_heads`: {self.num_heads})."
            )

        dense = partial(
            DenseGeneral,
            self.embed_dim,
            axis=-1,
            dtype=self.dtype,
            params_dtype=self.params_dtype,
            kernel_axes=("embed", "joined_kv"),
        )

        self.q_proj = dense(use_bias=self.bias)
        self.k_proj = dense(use_bias=False)
        self.v_proj = dense(use_bias=self.bias)

        self.out_proj = DenseGeneral(
            self.embed_dim,
            axis=-1,
            dtype=self.dtype,
            params_dtype=self.params_dtype,
            kernel_axes=("joined_kv", "embed"),
            use_bias=self.bias,
        )

        if self.causal:
            self.causal_mask = make_causal_mask(
                jnp.ones((1, self.config.max_target_positions), dtype="bool"),
                dtype="bool",
            )