in src/wavenet.py [0:0]
def forward(self, x, c=None):
    """Run the WaveNet stack on ``x`` and return the output of ``self.logits``.

    Args:
        x: input samples. A tensor with fewer than 3 dims gets a new axis
            inserted at position 1 (presumably the channel axis; confirm
            the expected layout against callers). Values are rescaled with
            ``x / 255 - 0.5``, which maps 0..255 onto -0.5..0.5.
        c: optional conditioning tensor. When given, it is upsampled with
            ``self._upsample_cond(x, c)``, passed to every residual layer,
            and applied after ``self.fc`` via ``self._condition``.

    Returns:
        The result of ``self.logits`` applied to the post-processed sum of
        skip contributions.
    """
    if x.dim() < 3:
        x = x.unsqueeze(1)
    # Floating inputs of at most 4 bytes (half, bfloat16, float32) keep their
    # dtype, presumably so a reduced-precision model is not fed float32.
    # Integer, bool and float64 inputs are cast to float32. This checks the
    # dtype directly instead of substring-matching the legacy x.type() string.
    if not x.is_floating_point() or x.element_size() > 4:
        x = x.float()
    # NOTE(review): the scaling is applied to every input dtype, which assumes
    # floating inputs also carry raw 0..255 values. Confirm with the data
    # pipeline.
    x = x / 255 - 0.5
    if self.shift_input:
        x = self.shift_right(x)
    if c is not None:
        c = self._upsample_cond(x, c)

    residual = self.first_conv(x)
    skip = self.skip_conv(residual)
    for layer in self.layers:
        # Each layer yields (residual update, skip contribution).
        r, s = layer(residual, c)
        residual = residual + r
        skip = skip + s

    # Output head: relu -> fc -> (optional conditioning) -> relu -> logits.
    skip = F.relu(skip)
    skip = self.fc(skip)
    if c is not None:
        skip = self._condition(skip, c, self.condition)
    skip = F.relu(skip)
    return self.logits(skip)