in training/flax/distil_whisper/layers.py [0:0]
def __call__(self, inputs: Array) -> Array:
    """Applies a linear transformation to the inputs along multiple dimensions.

    The dimensions of `inputs` selected by `self.axis` are contracted against
    the leading dimensions of a kernel of shape
    `(*contracted_input_dims, *features)`. When `self.use_bias` is set, a bias
    of shape `features` is added to the trailing dimensions of the result.

    Args:
        inputs: The nd-array to be transformed. It is cast to `self.dtype`
            before the contraction.

    Returns:
        The transformed input: the non-contracted dimensions of `inputs`
        followed by `features`.
    """
    features = _canonicalize_tuple(self.features)
    axis = _canonicalize_tuple(self.axis)

    inputs = jnp.asarray(inputs, self.dtype)
    axis = _normalize_axes(axis, inputs.ndim)

    # The kernel is laid out as (contracted input dims..., output features...).
    kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
    # Positions of the input / output dimensions inside the kernel; forwarded
    # to the kernel initializer as extra positional arguments.
    kernel_in_axis = np.arange(len(axis))
    kernel_out_axis = np.arange(len(axis), len(axis) + len(features))
    kernel = param_with_axes(
        "kernel",
        self.kernel_init,
        kernel_shape,
        self.params_dtype,
        kernel_in_axis,
        kernel_out_axis,
        axes=self.kernel_axes,
    )
    if self.use_bias:
        # NOTE(review): only the last kernel axis name is attached to the bias,
        # although the bias has len(features) dimensions, and this indexing
        # needs a non-empty `kernel_axes` -- confirm against param_with_axes
        # before relying on multi-dimensional `features` with a bias.
        bias = param_with_axes(
            "bias",
            self.bias_init,
            features,
            self.params_dtype,
            axes=(self.kernel_axes[-1],),
        )

    kernel = jnp.asarray(kernel, self.dtype)

    # Contract the selected input axes with the leading kernel axes; there are
    # no batch dimensions in the dot_general sense.
    contract_ind = tuple(range(len(axis)))
    y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))

    if self.use_bias:
        bias = jnp.asarray(bias, self.dtype)
        # Pad the bias with leading singleton dims so it lines up with the
        # trailing `features` dims of y. The previous expression computed
        # `len(features) - y.ndim`, which is never positive, so the padding was
        # silently empty and the add relied on implicit broadcasting.
        leading_ones = (1,) * (y.ndim - len(features))
        y += jnp.reshape(bias, leading_ones + bias.shape)
    return y