in src/wavenet.py [0:0]
def __init__(self, args, create_layers=True, shift_input=True):
    """Copy hyper-parameters from ``args`` and build the network's sub-modules.

    Args:
        args: configuration object. The attributes read here are ``blocks``,
            ``layers``, ``kernel_size``, ``skip_channels``,
            ``residual_channels`` and ``latent_d``.
        create_layers: when true, build ``self.blocks * self.layer_num``
            ``WavenetLayer`` modules and store them in ``self.layers``
            (an ``nn.ModuleList``). When false, ``self.layers`` is NOT
            assigned by this constructor.
        shift_input: stored unchanged as ``self.shift_input``.
    """
    super().__init__()

    # Hyper-parameters, copied from the configuration object.
    self.blocks = args.blocks
    self.layer_num = args.layers
    self.kernel_size = args.kernel_size
    self.skip_channels = args.skip_channels
    self.residual_channels = args.residual_channels
    self.cond_channels = args.latent_d
    self.classes = 256  # hard-coded output width of the ``logits`` conv
    self.shift_input = shift_input

    if create_layers:
        # ``self.blocks`` repetitions of ``self.layer_num`` layers. The value
        # passed as the dilation doubles within a block (1, 2, 4, ...) and
        # restarts at 1 for every new block. Construction order matches a
        # block-major nested loop.
        self.layers = nn.ModuleList([
            WavenetLayer(self.residual_channels, self.skip_channels, self.cond_channels,
                         self.kernel_size, 2 ** depth)
            for _block in range(self.blocks)
            for depth in range(self.layer_num)
        ])

    # Causal convolution from 1 channel to the residual width
    # (CausalConv1d is defined elsewhere; argument order presumed to be
    # in_channels, out_channels — confirm against its definition).
    self.first_conv = CausalConv1d(1, self.residual_channels, kernel_size=self.kernel_size)
    # 1x1 projections onto the skip width: from the residual width, and from
    # ``cond_channels`` (taken from ``args.latent_d``).
    self.skip_conv = nn.Conv1d(self.residual_channels, self.skip_channels, kernel_size=1)
    self.condition = nn.Conv1d(self.cond_channels, self.skip_channels, kernel_size=1)
    # Output head: skip width -> skip width, then skip width -> ``classes``.
    self.fc = nn.Conv1d(self.skip_channels, self.skip_channels, kernel_size=1)
    self.logits = nn.Conv1d(self.skip_channels, self.classes, kernel_size=1)