def infer()

in torchaudio/models/wavernn.py


    def infer(self, specgram: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        r"""Inference method of WaveRNN.

        This method currently only supports multinomial sampling, which assumes the
        network was trained with cross-entropy loss.

        Args:
            specgram (Tensor):
                Batch of spectrograms. Shape: `(n_batch, n_freq, n_time)`.
            lengths (Tensor or None, optional):
                Indicates the valid length of each audio in the batch.
                Shape: `(batch, )`.
                When ``specgram`` contains spectrograms with different durations,
                providing the ``lengths`` argument lets the model compute
                the corresponding valid output lengths.
                If ``None``, it is assumed that all the spectrograms in ``specgram``
                have valid lengths. Default: ``None``.

        Returns:
            (Tensor, Optional[Tensor]):
            Tensor
                The inferred waveform of size `(n_batch, 1, n_time)`, where
                `n_time` is the number of generated audio samples.
                1 stands for a single channel.
            Tensor or None
                If the ``lengths`` argument was provided, a Tensor of shape
                `(batch, )` is returned.
                It indicates the valid length along the time axis of the output
                Tensor.
        """

        device = specgram.device
        dtype = specgram.dtype

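        # Pad the spectrogram to compensate for the frames consumed by the
        # conditioning network, then upsample it to the waveform sample rate
        # together with the auxiliary features.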
        specgram = torch.nn.functional.pad(specgram, (self._pad, self._pad))
        specgram, aux = self.upsample(specgram)
        if lengths is not None:
            lengths = lengths * self.upsample.total_scale

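        # Collect one generated sample per upsampled time step.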
        output: List[Tensor] = []
        b_size, _, seq_len = specgram.size()

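        # Zero-initialize the hidden states of both GRU layers and the previous
        # sample ``x`` that seeds the autoregression.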
        h1 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
        h2 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
        x = torch.zeros((b_size, 1), device=device, dtype=dtype)

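        # Split the auxiliary features into four chunks, one per stage of the
        # network below.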
        aux_split = [aux[:, self.n_aux * i : self.n_aux * (i + 1), :] for i in range(4)]

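        # Autoregressive sampling loop: each step conditions on the sample drawn
        # at the previous step.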
        for i in range(seq_len):

            m_t = specgram[:, :, i]

            a1_t, a2_t, a3_t, a4_t = [a[:, :, i] for a in aux_split]

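            # Concatenate the previous sample, the current conditioning frame and
            # the first auxiliary chunk, then run the first GRU.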
            x = torch.cat([x, m_t, a1_t], dim=1)
            x = self.fc(x)
            _, h1 = self.rnn1(x.unsqueeze(1), h1)

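            # Residual connection around the first GRU, then the second GRU
            # conditioned on the second auxiliary chunk.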
            x = x + h1[0]
            inp = torch.cat([x, a2_t], dim=1)
            _, h2 = self.rnn2(inp.unsqueeze(1), h2)

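            # Residual connection around the second GRU, followed by two fully
            # connected layers fed with the remaining auxiliary chunks.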
            x = x + h2[0]
            x = torch.cat([x, a3_t], dim=1)
            x = F.relu(self.fc1(x))

            x = torch.cat([x, a4_t], dim=1)
            x = F.relu(self.fc2(x))

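            # Unnormalized scores over the n_classes quantization levels.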
            logits = self.fc3(x)

            posterior = F.softmax(logits, dim=1)

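            # Draw the next sample from the categorical posterior.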
            x = torch.multinomial(posterior, 1).float()
            # Transform label [0, 2 ** n_bits - 1] to waveform [-1, 1]
            x = 2 * x / (2 ** self.n_bits - 1.0) - 1.0

            output.append(x)

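        # Stack the per-step samples: (n_time, n_batch, 1) -> (n_batch, 1, n_time).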
        return torch.stack(output).permute(1, 2, 0), lengths
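
Below is a minimal usage sketch, assuming a torchaudio release that provides
torchaudio.models.WaveRNN with this infer method; the configuration values
(upsample scales, hop length, number of frequency bins, batch and frame counts)
are illustrative only.

    import torch
    from torchaudio.models import WaveRNN

    # Illustrative configuration: the product of upsample_scales must equal hop_length.
    model = WaveRNN(upsample_scales=[5, 5, 8], n_classes=512, hop_length=200, n_freq=80)
    model.eval()

    specgram = torch.rand(2, 80, 100)   # (n_batch, n_freq, n_time)
    lengths = torch.tensor([100, 80])   # valid frames per example

    with torch.no_grad():
        waveform, out_lengths = model.infer(specgram, lengths)

    # waveform: (n_batch, 1, n_time * hop_length); out_lengths holds the number
    # of valid samples for each example.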