def infer()

in torchaudio/models/tacotron2.py


    def infer(self, memory: Tensor, memory_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Decoder inference

        Args:
            memory (Tensor): Encoder outputs
                with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
            memory_lengths (Tensor): Encoder output lengths for attention masking
                (the same as ``text_lengths``) with shape (n_batch, ).

        Returns:
            mel_specgram (Tensor): Predicted mel spectrogram
                with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
            mel_specgram_lengths (Tensor): The length of the predicted mel spectrogram
                with shape (n_batch, ).
            gate_outputs (Tensor): Predicted stop token for each timestep
                with shape (n_batch, max of ``mel_specgram_lengths``).
            alignments (Tensor): Sequence of attention weights from the decoder
                with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``).
        """
        batch_size, device = memory.size(0), memory.device

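        # Start decoding from the initial "go" frame.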
        decoder_input = self._get_go_frame(memory)

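        # Attention mask derived from the encoder output lengths; decoder and attention states start from their initial values.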
        mask = _get_mask_from_lengths(memory_lengths)
        (
            attention_hidden,
            attention_cell,
            decoder_hidden,
            decoder_cell,
            attention_weights,
            attention_weights_cum,
            attention_context,
            processed_memory,
        ) = self._initialize_decoder_states(memory)

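        # Per-sequence output lengths and stop flags, updated as decoding progresses.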
        mel_specgram_lengths = torch.zeros([batch_size], dtype=torch.int32, device=device)
        finished = torch.zeros([batch_size], dtype=torch.bool, device=device)
        mel_specgrams: List[Tensor] = []
        gate_outputs: List[Tensor] = []
        alignments: List[Tensor] = []
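        # Autoregressive loop: predict one mel frame (and a stop token) per step, up to decoder_max_step.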
        for _ in range(self.decoder_max_step):
            decoder_input = self.prenet(decoder_input)
            (
                mel_specgram,
                gate_output,
                attention_hidden,
                attention_cell,
                decoder_hidden,
                decoder_cell,
                attention_weights,
                attention_weights_cum,
                attention_context,
            ) = self.decode(
                decoder_input,
                attention_hidden,
                attention_cell,
                decoder_hidden,
                decoder_cell,
                attention_weights,
                attention_weights_cum,
                attention_context,
                memory,
                processed_memory,
                mask,
            )

            mel_specgrams.append(mel_specgram.unsqueeze(0))
            gate_outputs.append(gate_output.transpose(0, 1))
            alignments.append(attention_weights)
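            # Only sequences that have not yet finished grow in length.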
            mel_specgram_lengths[~finished] += 1

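            # Mark sequences whose stop-token probability exceeds the gate threshold; optionally stop once every sequence has finished.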
            finished |= torch.sigmoid(gate_output.squeeze(1)) > self.gate_threshold
            if self.decoder_early_stopping and torch.all(finished):
                break

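            # Feed the predicted frame back as the next decoder input.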
            decoder_input = mel_specgram

        if len(mel_specgrams) == self.decoder_max_step:
            warnings.warn(
                "Reached max decoder steps. The generated spectrogram might not cover " "the whole transcript."
            )

        mel_specgrams = torch.cat(mel_specgrams, dim=0)
        gate_outputs = torch.cat(gate_outputs, dim=0)
        alignments = torch.cat(alignments, dim=0)

        mel_specgrams, gate_outputs, alignments = self._parse_decoder_outputs(mel_specgrams, gate_outputs, alignments)

        return mel_specgrams, mel_specgram_lengths, gate_outputs, alignments
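
A minimal usage sketch (not part of the source): this method is normally reached through Tacotron2.infer, but it can be exercised directly on the encoder outputs. The attribute names model.embedding, model.encoder, and model.decoder, and the way the encoder is invoked, are assumptions based on the Tacotron2 module in this file; the shapes are illustrative.

import torch
from torchaudio.models import Tacotron2

model = Tacotron2()
model.eval()

tokens = torch.randint(0, 38, (1, 50))           # (n_batch, max of text_lengths)
lengths = torch.tensor([50], dtype=torch.int32)  # (n_batch, )

with torch.no_grad():
    # Encoder outputs serve as the decoder's `memory`.
    embedded = model.embedding(tokens).transpose(1, 2)
    memory = model.encoder(embedded, lengths)
    mel, mel_lengths, gate_outputs, alignments = model.decoder.infer(memory, lengths)

print(mel.shape)    # (n_batch, n_mels, max of mel_specgram_lengths)
print(mel_lengths)  # number of decoded frames per sequence

With untrained weights the stop token may never fire, so the loop typically runs until decoder_max_step and emits the warning above; with a trained checkpoint the decoding stops once every sequence has predicted its stop token.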