def __call__()

in training/flax/distil_whisper/layers.py [0:0]


    def __call__(self, inputs: Array) -> Array:
        """Applies a linear transformation to the inputs along multiple dimensions.

        The input axes listed in ``self.axis`` are contracted against the
        leading dimensions of a kernel of shape ``(*contracted_dims, *features)``.
        The result keeps the non-contracted input dimensions (in their original
        order) followed by ``features`` as trailing dimensions.

        Args:
          inputs: The nd-array to be transformed. It is cast to ``self.dtype``
            before the contraction.

        Returns:
          The transformed input, in ``self.dtype``.
        """
        features = _canonicalize_tuple(self.features)
        axis = _canonicalize_tuple(self.axis)

        inputs = jnp.asarray(inputs, self.dtype)
        # Normalize the requested axes against the input rank (helper defined
        # elsewhere in this module).
        axis = _normalize_axes(axis, inputs.ndim)

        # Kernel layout: contracted input dims first, then the output features.
        kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
        kernel_in_axis = np.arange(len(axis))
        kernel_out_axis = np.arange(len(axis), len(axis) + len(features))
        kernel = param_with_axes(
            "kernel",
            self.kernel_init,
            kernel_shape,
            self.params_dtype,
            kernel_in_axis,
            kernel_out_axis,
            axes=self.kernel_axes,
        )
        if self.use_bias:
            # NOTE(review): only the last kernel axis name is attached to the
            # bias, so for multi-dimensional ``features`` the axis metadata has
            # fewer entries than the bias has dimensions, and ``[-1]`` fails if
            # ``kernel_axes`` is empty — confirm both are intended.
            bias = param_with_axes(
                "bias",
                self.bias_init,
                features,
                self.params_dtype,
                axes=(self.kernel_axes[-1],),
            )
        kernel = jnp.asarray(kernel, self.dtype)

        # Contract the selected input axes with the leading kernel dims; there
        # are no batch dimensions, so the output is (free input dims, features).
        contract_ind = tuple(range(len(axis)))
        y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))
        if self.use_bias:
            bias = jnp.asarray(bias, self.dtype)
            # Prepend ``y.ndim - len(features)`` singleton dims so the bias
            # (shape ``features``) lines up with the trailing feature dims of y.
            y += jnp.reshape(bias, (1,) * (y.ndim - len(features)) + bias.shape)
        return y