in parler_tts/modeling_parler_tts.py [0:0]
def _prepare_prompt_kwargs_for_generation(self, prompt_input_ids, model_kwargs):
    prompt_hidden_states = self.embed_prompts(prompt_input_ids)

    if self.prompt_cross_attention:
        # add sinusoidal positional embedding
        positions = self.embed_positions(prompt_hidden_states, 0)
        prompt_hidden_states = prompt_hidden_states + positions.to(prompt_hidden_states.device)

        attention_mask = model_kwargs.get("attention_mask", None)
        prompt_attention_mask = model_kwargs.get("prompt_attention_mask", None)
        encoder_hidden_states = model_kwargs["encoder_outputs"].last_hidden_state

        if prompt_attention_mask is not None and attention_mask is None:
            attention_mask = torch.ones(
                encoder_hidden_states.shape[:2], device=self.device, dtype=prompt_attention_mask.dtype
            )
        elif attention_mask is not None and prompt_attention_mask is None:
            prompt_attention_mask = torch.ones(
                prompt_hidden_states.shape[:2], device=self.device, dtype=attention_mask.dtype
            )

        # concatenate text description states with prompt description states
        encoder_hidden_states = torch.cat([encoder_hidden_states, prompt_hidden_states], dim=1)
        if prompt_attention_mask is not None:
            attention_mask = torch.cat([attention_mask, prompt_attention_mask], dim=1)

        model_kwargs["encoder_outputs"].last_hidden_state = encoder_hidden_states
        model_kwargs["attention_mask"] = attention_mask

        # in this case, since we already concatenated the prompt hidden states and attention mask, we don't need them anymore.
        model_kwargs["prompt_hidden_states"] = None
        model_kwargs["prompt_attention_mask"] = None
    else:
        model_kwargs["prompt_hidden_states"] = prompt_hidden_states
        # we're keeping the prompt attention mask because it has to be prepended to the decoder attention mask on the fly
    return model_kwargs
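
For intuition, here is a minimal standalone sketch of the shape arithmetic in the cross-attention branch. The batch size, sequence lengths, and hidden size are made-up toy values, and the random tensors merely stand in for the description-encoder output and the embedded prompt, so this illustrates the concatenation logic under those assumptions rather than reproducing actual Parler-TTS code.

import torch

# Toy dimensions, assumed for illustration only (not Parler-TTS defaults).
batch, desc_len, prompt_len, hidden = 2, 12, 5, 8

# Stand-ins for the description encoder output and the embedded prompt tokens.
encoder_hidden_states = torch.randn(batch, desc_len, hidden)
prompt_hidden_states = torch.randn(batch, prompt_len, hidden)

# Here only the prompt mask is provided, so a ones mask is synthesized for the
# description side, mirroring the first branch of the function above.
prompt_attention_mask = torch.ones(batch, prompt_len, dtype=torch.long)
attention_mask = torch.ones(encoder_hidden_states.shape[:2], dtype=prompt_attention_mask.dtype)

# Concatenate along the sequence dimension, as in the cross-attention branch.
encoder_hidden_states = torch.cat([encoder_hidden_states, prompt_hidden_states], dim=1)
attention_mask = torch.cat([attention_mask, prompt_attention_mask], dim=1)

print(encoder_hidden_states.shape)  # torch.Size([2, 17, 8])
print(attention_mask.shape)         # torch.Size([2, 17])

Concatenating along dim=1 is what lets a single cross-attention pass attend jointly to the description tokens and the prompt tokens; in the non-cross-attention branch the prompt states are instead kept separate and handled on the decoder side, as the final comment in the function notes.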