in src/wavenet_generator.py [0:0]
def init(self, batch_size=None):
    """(Re)initialise the incremental-generation queues of ``self.wavenet``.

    Zero-filled tensors are passed to ``init_queue`` in two groups:

    * ``self.wavenet.first_conv`` receives shape ``(batch_size, 1, 1)``.
    * The ``causal`` module of every entry in ``self.wavenet.layers``
      receives shape ``(batch_size, residual_channels, 1)``.

    When ``self.was_cuda`` is truthy, each tensor is moved with ``.cuda()``.
    ``self.wavenet.cuda()`` is also called after all queues are set up.

    Args:
        batch_size: If not ``None``, stored into ``self.batch_size`` before
            any tensor is built. Otherwise the current ``self.batch_size``
            is reused.
    """
    def blank(channels):
        # Width-1 zero tensor; moved to the GPU only when the generator
        # recorded ``was_cuda``.
        zeros = torch.zeros(self.batch_size, channels, 1)
        if self.was_cuda:
            zeros = zeros.cuda()
        return zeros

    if batch_size is not None:
        self.batch_size = batch_size

    self.wavenet.first_conv.init_queue(blank(1))

    # NOTE(review): one tensor object is handed to every layer's queue. That
    # is only safe if init_queue copies it or never writes to it in place.
    # Confirm against init_queue's implementation.
    residual_zeros = blank(self.wavenet.residual_channels)
    for causal_conv in (layer.causal for layer in self.wavenet.layers):
        causal_conv.init_queue(residual_zeros)

    if self.was_cuda:
        self.wavenet.cuda()