def __init__()

in parler_tts/modeling_parler_tts.py


    def __init__(self, config: ParlerTTSDecoderConfig):
        super().__init__(config)
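        # Regularization (dropout, LayerDrop) and sizing hyperparameters, read straight from the config.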
        self.dropout = config.dropout
        self.layerdrop = config.layerdrop
        self.max_target_positions = config.max_position_embeddings
        self.d_model = config.hidden_size
        self.num_codebooks = config.num_codebooks
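        # Optionally scale token embeddings by sqrt(d_model), as in the original Transformer.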
        self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0

        # TODO(YL): actually doesn't need the +1 if initialized correctly. Too late to change now.
        embed_dim = config.vocab_size + 1  # + 1 for pad token id
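        # One embedding table per codebook; in the forward pass the per-codebook
        # lookups are summed into a single hidden-state sequence.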
        self.embed_tokens = nn.ModuleList(
            [nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)]
        )

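        # Positional information: sinusoidal absolute embeddings added to the hidden
        # states by default, or rotary embeddings (RoPE) applied per-head inside attention.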
        self.rope_embeddings = config.rope_embeddings
        if not config.rope_embeddings:
            self.embed_positions = ParlerTTSSinusoidalPositionalEmbedding(
                config.max_position_embeddings,
                config.hidden_size,
            )
        else:
            self.rotary_emb = ParlerTTSRotaryEmbedding(
                config.hidden_size // config.num_attention_heads,
                max_position_embeddings=config.max_position_embeddings,
                base=config.rope_theta,
            )
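
        # Stack of decoder layers; each layer self-attends over the audio tokens and
        # cross-attends into the text encoder's hidden states.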
        self.layers = nn.ModuleList(
            [ParlerTTSDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.attn_implementation = config._attn_implementation
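
        # The cross-attention backend can be pinned to "sdpa" or "eager" independently
        # of the self-attention backend via cross_attention_implementation_strategy.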
        encoder_attn_implementation = config._attn_implementation
        if config.cross_attention_implementation_strategy is not None:
            encoder_attn_implementation = (
                "sdpa" if config.cross_attention_implementation_strategy == "always_sdpa" else "eager"
            )
        self.encoder_attn_implementation = encoder_attn_implementation
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()
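
For orientation, here is a minimal construction sketch. It assumes the enclosing class is ParlerTTSDecoder (implied by the ParlerTTSDecoderConfig type hint) and that ParlerTTSDecoderConfig lives in parler_tts/configuration_parler_tts.py per the usual transformers layout; the field values below are illustrative, not a released checkpoint.

    from parler_tts.configuration_parler_tts import ParlerTTSDecoderConfig
    from parler_tts.modeling_parler_tts import ParlerTTSDecoder

    # Hypothetical small config; the field names mirror the attributes consumed above.
    config = ParlerTTSDecoderConfig(
        vocab_size=1088,
        hidden_size=256,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_codebooks=9,
        max_position_embeddings=1024,
        rope_embeddings=False,  # sinusoidal absolute positions rather than RoPE
    )
    decoder = ParlerTTSDecoder(config)

    # One embedding table per codebook, each with vocab_size + 1 rows (extra pad row).
    assert len(decoder.embed_tokens) == config.num_codebooks
    assert decoder.embed_tokens[0].num_embeddings == config.vocab_size + 1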