in svoice/models/swave.py [0:0]
def forward(self, input):
    """Run the stacked row/column passes and collect the projected outputs.

    Each layer applies ``rows_grnn[i]`` and then ``cols_grnn[i]`` to
    reshaped views of the running activation, normalizes each result and
    adds it back as a residual. The running activation is then projected
    by ``self.output``.

    Args:
        input: 4-D tensor; presumably laid out as
            ``(batch, features, d1, d2)`` -- TODO confirm with the caller.

    Returns:
        A list of ``self.output`` projections. In training mode it holds
        one entry per layer; otherwise it holds only the last layer's
        projection (and is empty when ``num_layers`` is 0).

    Note:
        Outside training mode the residual sums are accumulated with
        ``+=``, which for tensors updates the buffer in place, so the
        caller's ``input`` tensor is overwritten.
    """
    batch_size, _, d1, d2 = input.shape
    output = input
    output_all = []
    last_layer = self.num_layers - 1
    for i in range(self.num_layers):
        # Row pass: fold d2 into the batch axis -> (batch * d2, d1, features).
        row_input = output.permute(0, 3, 2, 1).contiguous().view(
            batch_size * d2, d1, -1)
        row_output = self.rows_grnn[i](row_input)
        # Restore the original 4-D axis order before normalizing.
        row_output = row_output.view(
            batch_size, d2, d1, -1).permute(0, 3, 2, 1).contiguous()
        row_output = self.rows_normalization[i](row_output)
        # Skip connection: out-of-place while training, in-place otherwise
        # (the in-place form avoids allocating a new tensor per layer).
        if self.training:
            output = output + row_output
        else:
            output += row_output

        # Column pass: fold d1 into the batch axis -> (batch * d1, d2, features).
        col_input = output.permute(0, 2, 3, 1).contiguous().view(
            batch_size * d1, d2, -1)
        col_output = self.cols_grnn[i](col_input)
        col_output = col_output.view(
            batch_size, d1, d2, -1).permute(0, 3, 1, 2).contiguous()
        col_output = self.cols_normalization[i](col_output).contiguous()
        # Skip connection, same training / inference split as above.
        if self.training:
            output = output + col_output
        else:
            output += col_output

        # Only run the output head when its result is actually kept; at
        # inference the intermediate projections would be discarded.
        if self.training or i == last_layer:
            output_all.append(self.output(output))
    return output_all