in torchaudio/models/tacotron2.py
def infer(self, memory: Tensor, memory_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """Decoder inference

    Args:
        memory (Tensor): Encoder outputs
            with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
        memory_lengths (Tensor): Encoder output lengths for attention masking
            (the same as ``text_lengths``) with shape (n_batch, ).

    Returns:
        mel_specgram (Tensor): Predicted mel spectrogram
            with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
        mel_specgram_lengths (Tensor): The lengths of the predicted mel spectrograms
            with shape (n_batch, ).
        gate_outputs (Tensor): Predicted stop token for each timestep
            with shape (n_batch, max of ``mel_specgram_lengths``).
        alignments (Tensor): Sequence of attention weights from the decoder
            with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``).
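
    Example (an illustrative sketch, not part of the original docstring; assumes
    ``decoder`` is an instance of this class and that ``memory`` / ``memory_lengths``
    come from the upstream encoder):
        >>> mel_specgram, mel_lengths, gate_outputs, alignments = decoder.infer(memory, memory_lengths)
        >>> mel_specgram.shape  # (n_batch, n_mels, max of mel_specgram_lengths)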
"""
    batch_size, device = memory.size(0), memory.device
    decoder_input = self._get_go_frame(memory)
    mask = _get_mask_from_lengths(memory_lengths)
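    # Zero-initialize the attention and decoder LSTM states; processed_memory is
    # the encoder output pre-projected for attention score computation.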
    (
        attention_hidden,
        attention_cell,
        decoder_hidden,
        decoder_cell,
        attention_weights,
        attention_weights_cum,
        attention_context,
        processed_memory,
    ) = self._initialize_decoder_states(memory)
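    # Per-utterance frame counts and stop flags; per-step outputs are collected
    # in lists and concatenated once decoding ends.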
    mel_specgram_lengths = torch.zeros([batch_size], dtype=torch.int32, device=device)
    finished = torch.zeros([batch_size], dtype=torch.bool, device=device)
    mel_specgrams: List[Tensor] = []
    gate_outputs: List[Tensor] = []
    alignments: List[Tensor] = []
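
    # Autoregressive loop: each iteration emits one mel frame, up to decoder_max_step.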
    for _ in range(self.decoder_max_step):
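        # Run the previous frame through the prenet before conditioning the decoder on it.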
        decoder_input = self.prenet(decoder_input)
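        # One decoder step: attend over the encoder memory and produce the next mel
        # frame, a stop-token logit, and updated attention/recurrent states.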
        (
            mel_specgram,
            gate_output,
            attention_hidden,
            attention_cell,
            decoder_hidden,
            decoder_cell,
            attention_weights,
            attention_weights_cum,
            attention_context,
        ) = self.decode(
            decoder_input,
            attention_hidden,
            attention_cell,
            decoder_hidden,
            decoder_cell,
            attention_weights,
            attention_weights_cum,
            attention_context,
            memory,
            processed_memory,
            mask,
        )
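        # Record this step's outputs; only utterances that are still running have
        # their length incremented.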
        mel_specgrams.append(mel_specgram.unsqueeze(0))
        gate_outputs.append(gate_output.transpose(0, 1))
        alignments.append(attention_weights)
        mel_specgram_lengths[~finished] += 1
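        # An utterance is done once its stop-token probability exceeds the gate
        # threshold; optionally break as soon as the whole batch has finished.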
        finished |= torch.sigmoid(gate_output.squeeze(1)) > self.gate_threshold
        if self.decoder_early_stopping and torch.all(finished):
            break
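        # Feed the generated frame back in as the next decoder input.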
        decoder_input = mel_specgram

    if len(mel_specgrams) == self.decoder_max_step:
        warnings.warn("Reached max decoder steps. The generated spectrogram might not cover the whole transcript.")
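    # Concatenate the per-step outputs along the step dimension, then reshape them
    # into the batch-first layouts documented above.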
    mel_specgrams = torch.cat(mel_specgrams, dim=0)
    gate_outputs = torch.cat(gate_outputs, dim=0)
    alignments = torch.cat(alignments, dim=0)
    mel_specgrams, gate_outputs, alignments = self._parse_decoder_outputs(mel_specgrams, gate_outputs, alignments)
    return mel_specgrams, mel_specgram_lengths, gate_outputs, alignments