def _prepare_prompt_kwargs_for_generation()

in parler_tts/modeling_parler_tts.py


    def _prepare_prompt_kwargs_for_generation(self, prompt_input_ids, model_kwargs):
        """Embed the transcript prompt and route it either into the encoder outputs
        (prompt cross-attention) or into `model_kwargs`, where it is prepended to the
        decoder inputs on the fly."""
        prompt_hidden_states = self.embed_prompts(prompt_input_ids)

        if self.prompt_cross_attention:
            # add sinusoidal positional embedding
            positions = self.embed_positions(prompt_hidden_states, 0)
            prompt_hidden_states = prompt_hidden_states + positions.to(prompt_hidden_states.device)

            attention_mask = model_kwargs.get("attention_mask", None)
            prompt_attention_mask = model_kwargs.get("prompt_attention_mask", None)
            encoder_hidden_states = model_kwargs["encoder_outputs"].last_hidden_state

            # if only one of the two masks is provided, assume full attention for the
            # missing one so the masks can be concatenated below
            if prompt_attention_mask is not None and attention_mask is None:
                attention_mask = torch.ones(
                    encoder_hidden_states.shape[:2], device=self.device, dtype=prompt_attention_mask.dtype
                )
            elif attention_mask is not None and prompt_attention_mask is None:
                prompt_attention_mask = torch.ones(
                    prompt_hidden_states.shape[:2], device=self.device, dtype=attention_mask.dtype
                )

            # concatenate the description encoder states with the prompt hidden states
            encoder_hidden_states = torch.cat([encoder_hidden_states, prompt_hidden_states], dim=1)
            if prompt_attention_mask is not None:
                attention_mask = torch.cat([attention_mask, prompt_attention_mask], dim=1)

            model_kwargs["encoder_outputs"].last_hidden_state = encoder_hidden_states
            model_kwargs["attention_mask"] = attention_mask

            # the prompt hidden states and attention mask are already folded into the encoder outputs, so they are no longer needed downstream
            model_kwargs["prompt_hidden_states"] = None
            model_kwargs["prompt_attention_mask"] = None
        else:
            model_kwargs["prompt_hidden_states"] = prompt_hidden_states
            # we're keeping the prompt attention mask because it has to be prepended to the decoder attention mask on the fly
        return model_kwargs
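
For intuition, below is a minimal, self-contained sketch of the mask handling in the prompt cross-attention branch, using toy tensors instead of real encoder outputs. The `concat_encoder_and_prompt` helper name and the tensor shapes are illustrative assumptions, not part of the Parler-TTS API.

import torch

def concat_encoder_and_prompt(encoder_hidden_states, prompt_hidden_states,
                              attention_mask=None, prompt_attention_mask=None):
    # Mirror of the cross-attention branch above: if only one mask is supplied,
    # assume full attention for the other, then concatenate along the sequence dim.
    if prompt_attention_mask is not None and attention_mask is None:
        attention_mask = torch.ones(
            encoder_hidden_states.shape[:2], dtype=prompt_attention_mask.dtype
        )
    elif attention_mask is not None and prompt_attention_mask is None:
        prompt_attention_mask = torch.ones(
            prompt_hidden_states.shape[:2], dtype=attention_mask.dtype
        )

    hidden_states = torch.cat([encoder_hidden_states, prompt_hidden_states], dim=1)
    if prompt_attention_mask is not None:
        attention_mask = torch.cat([attention_mask, prompt_attention_mask], dim=1)
    return hidden_states, attention_mask

# Toy example: batch of 2, description length 5, prompt length 3, hidden size 4.
enc = torch.randn(2, 5, 4)
prompt = torch.randn(2, 3, 4)
prompt_mask = torch.ones(2, 3, dtype=torch.long)

states, mask = concat_encoder_and_prompt(enc, prompt, prompt_attention_mask=prompt_mask)
print(states.shape)  # torch.Size([2, 8, 4])
print(mask.shape)    # torch.Size([2, 8])

Concatenating along the sequence dimension lets the decoder cross-attend to description and prompt tokens in a single pass, which is why the prompt kwargs can be dropped from `model_kwargs` in that branch.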