in training/flax/distil_whisper/layers.py [0:0]
def __call__(self, inputs, decode: bool = False, deterministic: bool = False):
    """Applies Transformer MlpBlock module.

    Every entry of ``self.activations`` gets its own input projection to
    ``self.intermediate_dim`` features, followed by that activation. The
    activated branches are then multiplied together elementwise. One entry
    such as ``("relu",)`` gives a plain MLP; two entries such as
    ``("gelu", "linear")`` give a gated variant. Dropout follows, and then a
    projection back to ``inputs.shape[-1]`` features.

    Args:
      inputs: Input array. Presumably laid out as (batch, length, embed),
        judging by the logical axis names used below -- TODO confirm.
      decode: Not read anywhere in this method. Presumably kept for signature
        compatibility with sibling blocks -- confirm against callers.
      deterministic: Forwarded as the ``deterministic`` argument of the
        ``nn.Dropout`` call.

    Returns:
      The output of the ``"wo"`` projection. Its last dimension equals
      ``inputs.shape[-1]``.
    """
    # A single branch keeps the plain "wi" parameter name. With several
    # branches each projection is numbered ("wi_0", "wi_1", ...) so each one
    # owns distinct parameters.
    multi_branch = len(self.activations) != 1

    branches = []
    for branch_idx, activation in enumerate(self.activations):
        layer_name = f"wi_{branch_idx}" if multi_branch else "wi"
        projected = DenseGeneral(
            self.intermediate_dim,
            dtype=self.dtype,
            kernel_init=self.kernel_init,
            kernel_axes=("embed", "mlp"),
            name=layer_name,
        )(inputs)
        act_fn = _convert_to_activation_function(activation)
        branches.append(act_fn(projected))

    # Combine the activated branches by elementwise product.
    hidden = functools.reduce(operator.mul, branches)

    # One dropout mask is shared across axis -2, the axis that the sharding
    # constraint below labels "length".
    dropout = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))
    hidden = dropout(hidden, deterministic=deterministic)
    hidden = with_sharding_constraint(hidden, ("batch", "length", "mlp"))

    # Project back to the feature width of the incoming activations.
    out_features = inputs.shape[-1]
    return DenseGeneral(
        out_features,
        dtype=self.dtype,
        kernel_init=self.kernel_init,
        kernel_axes=("mlp", "embed"),
        name="wo",
    )(hidden)