in training/flax/distil_whisper/modeling_flax_whisper.py [0:0]
def setup(self) -> None:
    """Build the decoder-layer submodules: causal self-attention,
    encoder cross-attention, their layer norms, and the feed-forward net."""
    cfg = self.config
    self.embed_dim = cfg.d_model

    # Self-attention over the decoder sequence; causal=True applies the
    # autoregressive mask.
    self.self_attn = FlaxWhisperAttention(
        config=cfg,
        embed_dim=self.embed_dim,
        num_heads=cfg.decoder_attention_heads,
        dropout=cfg.attention_dropout,
        causal=True,
        dtype=self.dtype,
        params_dtype=self.params_dtype,
    )
    self.self_attn_layer_norm = LayerNorm(
        dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype
    )

    # Cross-attention (no causal mask) — presumably attends over the encoder
    # hidden states; confirm against the layer's __call__.
    self.encoder_attn = FlaxWhisperAttention(
        config=cfg,
        embed_dim=self.embed_dim,
        num_heads=cfg.decoder_attention_heads,
        dropout=cfg.attention_dropout,
        dtype=self.dtype,
        params_dtype=self.params_dtype,
    )
    self.encoder_attn_layer_norm = LayerNorm(
        dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype
    )

    # Feed-forward network: embed -> mlp -> embed, with sharding axes
    # declared on each kernel.
    self.fc1 = DenseGeneral(
        cfg.decoder_ffn_dim,
        dtype=self.dtype,
        params_dtype=self.params_dtype,
        kernel_axes=("embed", "mlp"),
    )
    self.fc2 = DenseGeneral(
        self.embed_dim,
        dtype=self.dtype,
        params_dtype=self.params_dtype,
        kernel_axes=("mlp", "embed"),
    )
    self.final_layer_norm = LayerNorm(
        dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype
    )

    # Dropout layers and the configured FFN activation.
    self.dropout_layer = nn.Dropout(rate=cfg.dropout)
    self.activation_dropout_layer = nn.Dropout(rate=cfg.activation_dropout)
    self.activation_fn = ACT2FN[cfg.activation_function]