in torchaudio/models/wavernn.py [0:0]
def infer(self, specgram: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
    r"""Inference method of WaveRNN.

    This function currently only supports multinomial sampling, which assumes the
    network is trained on cross entropy loss.

    Args:
        specgram (Tensor):
            Batch of spectrograms. Shape: `(n_batch, n_freq, n_time)`.
        lengths (Tensor or None, optional):
            Indicates the valid length of each audio in the batch.
            Shape: `(batch, )`.
            When the ``specgram`` contains spectrograms with different durations,
            by providing ``lengths`` argument, the model will compute
            the corresponding valid output lengths.
            If ``None``, it is assumed that all the spectrograms in ``specgram``
            have valid length. Default: ``None``.

    Returns:
        (Tensor, Optional[Tensor]):
        Tensor
            The inferred waveform of size `(n_batch, 1, n_time')`, where `n_time'` is
            the time length of ``specgram`` after it has been padded by ``self._pad``
            frames on each side and passed through ``self.upsample`` (one sample is
            generated per upsampled frame). 1 stands for a single channel.
            The waveform has the same dtype and device as ``specgram``.
        Tensor or None
            If ``lengths`` argument was provided, a Tensor of shape `(batch, )`
            is returned, equal to ``lengths * self.upsample.total_scale``.
            It indicates the valid length in time axis of the output Tensor.
            Otherwise ``None``.
    """
    device = specgram.device
    dtype = specgram.dtype

    specgram = torch.nn.functional.pad(specgram, (self._pad, self._pad))
    specgram, aux = self.upsample(specgram)
    if lengths is not None:
        lengths = lengths * self.upsample.total_scale

    output: List[Tensor] = []
    b_size, _, seq_len = specgram.size()

    # Recurrent states and the previous sample all follow the input's dtype/device.
    h1 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
    h2 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
    x = torch.zeros((b_size, 1), device=device, dtype=dtype)

    # The auxiliary features are split into four equal chunks of n_aux channels;
    # each chunk is concatenated in at a different stage of the network below.
    aux_split = [aux[:, self.n_aux * i : self.n_aux * (i + 1), :] for i in range(4)]

    # Autoregressive loop: each step consumes the previously generated sample `x`.
    for i in range(seq_len):
        m_t = specgram[:, :, i]

        a1_t, a2_t, a3_t, a4_t = [a[:, :, i] for a in aux_split]

        x = torch.cat([x, m_t, a1_t], dim=1)
        x = self.fc(x)
        _, h1 = self.rnn1(x.unsqueeze(1), h1)

        # Residual connections around both recurrent layers.
        x = x + h1[0]
        inp = torch.cat([x, a2_t], dim=1)
        _, h2 = self.rnn2(inp.unsqueeze(1), h2)

        x = x + h2[0]
        x = torch.cat([x, a3_t], dim=1)
        x = F.relu(self.fc1(x))

        x = torch.cat([x, a4_t], dim=1)
        x = F.relu(self.fc2(x))

        logits = self.fc3(x)

        posterior = F.softmax(logits, dim=1)

        # Cast the sampled class index to the input dtype rather than hard-coded
        # float32: `x` is fed back into the (dtype-matched) layers on the next step,
        # and the stacked output should match the input precision.
        x = torch.multinomial(posterior, 1).to(dtype)
        # Transform label [0, 2 ** n_bits - 1] to waveform [-1, 1]
        x = 2 * x / (2**self.n_bits - 1.0) - 1.0

        output.append(x)

    return torch.stack(output).permute(1, 2, 0), lengths