in parler_tts/modeling_parler_tts.py [0:0]
    def __init__(self, config: ParlerTTSDecoderConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.embed_dim = config.hidden_size

        # Causal self-attention over the decoder sequence; the attention backend
        # (eager, sdpa, ...) is selected through config._attn_implementation.
        self.self_attn = PARLERTTS_ATTENTION_CLASSES[config._attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            num_key_value_heads=config.num_key_value_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            is_causal=True,
            bias=False,
            rope_embeddings=config.rope_embeddings,
            layer_idx=layer_idx,
            config=config,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        # Cross-attention uses the same backend as self-attention unless the config
        # pins it to a specific implementation ("always_eager" or "always_sdpa").
        cross_attn_implementation = config._attn_implementation
        if config.cross_attention_implementation_strategy == "always_eager":
            cross_attn_implementation = "eager"
        elif config.cross_attention_implementation_strategy == "always_sdpa":
            cross_attn_implementation = "sdpa"

        # Cross-attention over the encoder hidden states; it can use a different
        # number of key/value heads than self-attention.
        self.encoder_attn = PARLERTTS_ATTENTION_CLASSES[cross_attn_implementation](
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            num_key_value_heads=config.num_cross_attention_key_value_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            bias=False,
            rope_embeddings=config.rope_embeddings,
            layer_idx=layer_idx,
            config=config,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        # Bias-free feed-forward projections and the final layer norm.
        self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=False)
        self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=False)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)
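
    # Minimal construction sketch, kept as comments so the snippet stays valid Python.
    # Assumptions not shown in this excerpt: this __init__ belongs to the decoder layer
    # class (e.g. ParlerTTSDecoderLayer) and ParlerTTSDecoderConfig can be instantiated
    # with its default values.
    #
    #     config = ParlerTTSDecoderConfig()
    #     layer = ParlerTTSDecoderLayer(config, layer_idx=0)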