in models/wavernn.py [0:0]
def generate(self, spectrograms: Tensor, training: bool = False) -> Tensor:
    """
    Generate a sample from this model, one step at a time.

    The wrapped model is put in eval mode for the duration of generation.
    Afterwards its previous train/eval mode is restored. This happens even if
    generation raises (e.g. KeyboardInterrupt during the long sampling loop),
    so the model is never left in the wrong mode.

    Args:
        spectrograms: Rank-3 conditioning tensor. It is passed through the
            model's ``upsample`` stage before sampling. When ``training`` is
            True, only the first 200 entries of its last axis are used.
        training: If True, truncate ``spectrograms`` along its last axis to
            200 entries before generating (a short excerpt).

    Returns:
        A 1D float tensor containing the output waveform.
    """
    was_training = self.model.training
    self.model.eval()
    try:
        # The model appears to be a wrapper (e.g. DataParallel); the actual
        # layers live on `.module`.
        net = self.model.module
        if training:
            spectrograms = spectrograms[:, :, :200]
        output = []
        rnn1 = self.get_gru_cell(net.rnn1)
        rnn2 = self.get_gru_cell(net.rnn2)
        with torch.no_grad():
            spectrograms, aux = net.upsample(spectrograms)
            # Put the time axis second so each step can index [:, i, :].
            spectrograms = spectrograms.transpose(1, 2)
            aux = aux.transpose(1, 2)
            batch_size, seq_len, _ = spectrograms.size()
            h1 = spectrograms.new_zeros(batch_size, net.n_rnn)
            h2 = spectrograms.new_zeros(batch_size, net.n_rnn)
            # Previous sample, fed back as input at the next step; starts at 0.
            x = spectrograms.new_zeros(batch_size, 1)
            # The aux features are split into 4 equal chunks of width n_aux.
            # Each chunk conditions a different stage of the network below.
            d = net.n_aux
            aux_split = [aux[:, :, d * i : d * (i + 1)] for i in range(4)]
            for i in tqdm(range(seq_len)):
                m_t = spectrograms[:, i, :]
                a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split)
                x = torch.cat([x, m_t, a1_t], dim=1)
                x = net.fc(x)
                h1 = rnn1(x, h1)
                x = x + h1  # residual connection around the first GRU
                inp = torch.cat([x, a2_t], dim=1)
                h2 = rnn2(inp, h2)
                x = x + h2  # residual connection around the second GRU
                x = torch.cat([x, a3_t], dim=1)
                x = F.relu(net.fc1(x))
                x = torch.cat([x, a4_t], dim=1)
                x = F.relu(net.fc2(x))
                logits = net.fc3(x)
                # NOTE(review): presumably samples the next x from `logits`
                # and appends it to `output`; confirm in get_x_from_dist.
                x, output = self.get_x_from_dist("random", logits, output)
            # Stacked per-step outputs are (seq_len, batch, ...); transpose so
            # flattening concatenates each batch item's samples in turn.
            output = torch.stack(output).transpose(0, 1)
            output = self.expand(output.flatten())
    finally:
        # Restore the caller's mode rather than forcing train mode.
        self.model.train(was_training)
    return output