in models/src/wavenet_vocoder/modules.py [0:0]
def _forward(self, x, c, g, is_incremental):
    """Run one gated residual unit: dropout, conv, gated activation, 1x1 convs.

    Args:
        x (Tensor): Input features. B x C x T in batch mode. In incremental
            mode the channel axis is the LAST axis (the GLU split below uses
            ``dim=-1``), presumably B x T x C -- confirm against callers.
        c (Tensor or None): Local conditioning features, in the same layout
            convention as ``x``. None disables local conditioning.
        g (Tensor or None): Expanded global conditioning features, in the
            same layout convention as ``x``. None disables global conditioning.
        is_incremental (bool): If True, run ``self.conv.incremental_forward``
            (step-by-step generation); otherwise run ``self.conv`` on the
            whole sequence.

    Returns:
        tuple(Tensor, Tensor): ``(x, s)`` where, with
        ``h = tanh(a) * sigmoid(b)``,
        ``x = (conv1x1_out(h) + residual) * sqrt(0.5)`` is the residual
        output and ``s = conv1x1_skip(h)`` is the skip-connection output.

    Raises:
        AssertionError: if ``c`` (or ``g``) is given but the matching
            conditioning projection was not configured on this module.
    """
    residual = x
    x = F.dropout(x, p=self.dropout, training=self.training)
    if is_incremental:
        # Channels live on the last axis in incremental mode.
        splitdim = -1
        x = self.conv.incremental_forward(x)
    else:
        splitdim = 1
        x = self.conv(x)
        # Remove future time steps. This must stay inside the batch-mode
        # branch: here the last axis is time, whereas in incremental mode
        # the last axis holds the channels being split (splitdim == -1),
        # so trimming it there would truncate the gate channels instead.
        if self.causal:
            x = x[:, :, : residual.size(-1)]

    # Split the conv output into the two halves of the gated activation
    # along the channel axis.
    a, b = x.split(x.size(splitdim) // 2, dim=splitdim)

    # Local conditioning: project c and add each half to its gate half.
    if c is not None:
        assert self.conv1x1c is not None, (
            "local conditioning features given but conv1x1c is not configured"
        )
        c = _conv1x1_forward(self.conv1x1c, c, is_incremental)
        ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
        a, b = a + ca, b + cb

    # Global conditioning: same treatment as local conditioning.
    if g is not None:
        assert self.conv1x1g is not None, (
            "global conditioning features given but conv1x1g is not configured"
        )
        g = _conv1x1_forward(self.conv1x1g, g, is_incremental)
        ga, gb = g.split(g.size(splitdim) // 2, dim=splitdim)
        a, b = a + ga, b + gb

    # Gated activation unit.
    x = torch.tanh(a) * torch.sigmoid(b)

    # For skip connection
    s = _conv1x1_forward(self.conv1x1_skip, x, is_incremental)

    # For residual connection; scaling by sqrt(0.5) keeps the variance of
    # the sum of two terms comparable to that of a single term.
    x = _conv1x1_forward(self.conv1x1_out, x, is_incremental)
    x = (x + residual) * math.sqrt(0.5)
    return x, s