def __init__()

in src/wavenet.py [0:0]


    def __init__(self, args, create_layers=True, shift_input=True):
        """Store the hyper-parameters and build the convolutional sub-modules.

        Args:
            args: Configuration object. The attributes ``blocks``, ``layers``,
                ``kernel_size``, ``skip_channels``, ``residual_channels`` and
                ``latent_d`` are read from it.
            create_layers: When true, a stack of ``blocks * layers``
                ``WavenetLayer`` modules is built and registered as
                ``self.layers``. When false, ``self.layers`` is not assigned
                by this constructor.
            shift_input: Stored unchanged as ``self.shift_input``.
        """
        super().__init__()

        # Hyper-parameters copied from the configuration object.
        self.blocks = args.blocks
        self.layer_num = args.layers
        self.kernel_size = args.kernel_size
        self.skip_channels = args.skip_channels
        self.residual_channels = args.residual_channels
        self.cond_channels = args.latent_d
        # Output channel count of the final ``logits`` convolution.
        self.classes = 256
        self.shift_input = shift_input

        # Short aliases for the channel widths used repeatedly below.
        res_ch = self.residual_channels
        skip_ch = self.skip_channels
        cond_ch = self.cond_channels

        if create_layers:
            # Within each block the dilation doubles per layer (1, 2, 4, ...)
            # and starts over at 1 for the next block.
            self.layers = nn.ModuleList([
                WavenetLayer(res_ch, skip_ch, cond_ch, self.kernel_size, 2 ** depth)
                for _ in range(self.blocks)
                for depth in range(self.layer_num)
            ])

        # Input convolution: 1 channel -> residual channels.
        self.first_conv = CausalConv1d(1, res_ch, kernel_size=self.kernel_size)
        # Pointwise (kernel size 1) convolutions into / within the skip width.
        self.skip_conv = nn.Conv1d(res_ch, skip_ch, kernel_size=1)
        self.condition = nn.Conv1d(cond_ch, skip_ch, kernel_size=1)
        self.fc = nn.Conv1d(skip_ch, skip_ch, kernel_size=1)
        # Final pointwise projection from the skip width to ``classes`` channels.
        self.logits = nn.Conv1d(skip_ch, self.classes, kernel_size=1)