in src/wavenet_generator.py [0:0]
def __init__(
    self,
    wavenet: WaveNet,
    batch_size: int = 1,
    cond_repeat: int = 800,
    wav_freq: int = 16000,
) -> None:
    """Wrap a ``WaveNet`` model for incremental generation.

    The given model is modified in place:
    - its ``shift_input`` flag is set to ``False``;
    - its ``first_conv`` and every ``layer.causal`` convolution are replaced
      by ``QueuedConv1d`` wrappers;
    - it is switched to eval mode.

    Args:
        wavenet: Model to wrap. It is mutated in place, as listed above.
        batch_size: Leading (batch) dimension of the zero tensors handed to
            each ``QueuedConv1d``.
        cond_repeat: Stored as ``self.cond_repeat``. Presumably the number of
            audio samples per conditioning frame — TODO confirm with callers.
        wav_freq: Stored as ``self.wav_freq``. Presumably the sample rate in
            Hz — TODO confirm.
    """
    super().__init__()
    self.wavenet = wavenet
    self.wavenet.shift_input = False
    self.cond_repeat = cond_repeat
    self.wav_freq = wav_freq
    self.batch_size = batch_size

    # Use the device the parameters actually live on. `.cuda()` would target
    # the *current* CUDA device and silently relocate a model that sits on,
    # e.g., cuda:1 to cuda:0.
    first_param = next(self.wavenet.parameters())
    device = first_param.device
    self.was_cuda = first_param.is_cuda

    # Zero tensor handed to the QueuedConv1d wrapping the input convolution.
    # NOTE(review): presumably used to size/initialise its queue — confirm in
    # QueuedConv1d.
    self.wavenet.first_conv = QueuedConv1d(
        self.wavenet.first_conv,
        torch.zeros(self.batch_size, 1, 1, device=device),
    )

    # Allocate a separate tensor per layer so that no two QueuedConv1d
    # instances share storage, in case the wrapper keeps and updates it in
    # place.
    for layer in self.wavenet.layers:
        layer.causal = QueuedConv1d(
            layer.causal,
            torch.zeros(
                self.batch_size, self.wavenet.residual_channels, 1, device=device
            ),
        )

    # The freshly created wrapper modules may hold state created on CPU.
    # Move the whole model back onto the device its parameters came from.
    if self.was_cuda:
        self.wavenet.to(device)
    self.wavenet.eval()