in parler_tts/modeling_parler_tts.py [0:0]
def __init__(self, config: ParlerTTSDecoderConfig):
    """Build the decoder backbone and the language-model heads.

    Args:
        config: Decoder configuration. This constructor reads
            ``num_codebooks``, ``vocab_size``, ``hidden_size`` and
            ``use_fused_lm_heads`` from it, and passes it unchanged to the
            parent constructor and to ``ParlerTTSModel``.

    When ``config.use_fused_lm_heads`` is true, ``self.lm_heads`` is a single
    bias-free ``nn.Linear`` mapping ``hidden_size`` to
    ``vocab_size * num_codebooks`` features. Otherwise it is an
    ``nn.ModuleList`` holding ``num_codebooks`` bias-free
    ``hidden_size -> vocab_size`` linear layers, one per codebook.
    """
    super().__init__(config)
    self.model = ParlerTTSModel(config)
    # Cached from the config once. The original assigned num_codebooks twice;
    # the duplicate was redundant.
    self.num_codebooks = config.num_codebooks
    self.vocab_size = config.vocab_size
    self.use_fused_lm_heads = config.use_fused_lm_heads
    if self.use_fused_lm_heads:
        # A single projection covering all codebooks at once. Its output width
        # is vocab_size * num_codebooks; how that is split per codebook is
        # handled by the caller (not visible here).
        self.lm_heads = nn.Linear(config.hidden_size, config.vocab_size * config.num_codebooks, bias=False)
    else:
        # One independent projection per codebook.
        self.lm_heads = nn.ModuleList(
            [nn.Linear(config.hidden_size, config.vocab_size, bias=False) for _ in range(config.num_codebooks)]
        )
    # Initialize weights and apply final processing
    self.post_init()