def __init__()

in src/wavenet_generator.py [0:0]


    def __init__(self, wavenet: WaveNet, batch_size=1, cond_repeat=800, wav_freq=16000):
        """Wrap ``wavenet`` and swap its convolutions for ``QueuedConv1d`` wrappers.

        The passed-in model is modified in place: ``shift_input`` is set to
        ``False``, ``first_conv`` and every ``layer.causal`` are replaced by
        ``QueuedConv1d`` instances, and the model is switched to eval mode.

        Args:
            wavenet: Model to wrap. It must expose ``first_conv``,
                ``residual_channels`` and ``layers`` (each layer having a
                ``causal`` attribute) and own at least one parameter.
            batch_size: Leading dimension of the zero tensors handed to each
                ``QueuedConv1d``.
            cond_repeat: Stored as ``self.cond_repeat``; not used here.
            wav_freq: Stored as ``self.wav_freq``; not used here.

        Raises:
            StopIteration: If ``wavenet`` has no parameters.
        """
        super().__init__()
        self.wavenet = wavenet
        self.wavenet.shift_input = False
        self.cond_repeat = cond_repeat
        self.wav_freq = wav_freq
        self.batch_size = batch_size

        # Take the device from the model itself. A bare ``.cuda()`` targets
        # the *current* CUDA device, which can differ from the one the model
        # lives on (e.g. model on cuda:1 while the current device is cuda:0).
        first_param = next(self.wavenet.parameters())
        device = first_param.device
        self.was_cuda = first_param.is_cuda

        # NOTE(review): QueuedConv1d is defined elsewhere; presumably the zero
        # tensor seeds/sizes its internal queue -- confirm against its source.
        first_conv_seed = torch.zeros(self.batch_size, 1, 1, device=device)
        self.wavenet.first_conv = QueuedConv1d(self.wavenet.first_conv, first_conv_seed)

        # One seed tensor is shared by every layer, as in the original code.
        causal_seed = torch.zeros(
            self.batch_size, self.wavenet.residual_channels, 1, device=device
        )
        for layer in self.wavenet.layers:
            layer.causal = QueuedConv1d(layer.causal, causal_seed)

        # Bring any state created by the new wrappers onto the model's device
        # (no-op when everything is already there, including on CPU).
        self.wavenet.to(device)
        self.wavenet.eval()