in src/hyperconv.py [0:0]
def forward(self, x, z):
    '''
    Apply a hyper-convolution whose weights and bias are generated from z.

    :param x: the input signal as a B x ch_in x T tensor
    :param z: the weight-generating input as a B x z_dim x K tensor
        (K such that T is a multiple of K)
    :raises ValueError: if T is not a multiple of K
    :return: a B x ch_out x T tensor as the result of the hyper-convolution
    '''
    B = x.shape[0]
    K = z.shape[-1]
    # Explicit check instead of `assert` so the contract survives `python -O`.
    if x.shape[-1] % K != 0:
        raise ValueError(
            'T (%d) must be a multiple of K (%d)' % (x.shape[-1], K))
    # Left-only padding keeps the convolution causal: every output sample
    # depends only on current and past input samples.
    padding = self.dilation * (self.kernel_size - 1)
    x = F.pad(x, [padding, 0])
    # Linearize the receptive field into the channel dimension: stack the
    # kernel_size dilated shifts -> B x (ch_in * kernel_size) x T.
    start, end = padding, x.shape[-1]
    x = th.cat([x[:, :, start - i * self.dilation:end - i * self.dilation]
                for i in range(self.kernel_size)], dim=1)
    # Rearrange into K blocks of length T // K for batched matmul:
    # (B * K) x (T // K) x (ch_in * kernel_size).
    # NOTE: the shape arguments are evaluated on the pre-permute x, so
    # x.shape[-1] is T and x.shape[1] is ch_in * kernel_size.
    x = x.permute(0, 2, 1).contiguous().view(B * K, x.shape[-1] // K, x.shape[1])
    # Per-block weights generated from z: (B * K) x (ch_in * kernel_size) x ch_out.
    weight = self.weight_model(z).view(B, self.ch_in * self.kernel_size, self.ch_out, K)
    weight = weight.permute(0, 3, 1, 2).contiguous().view(B * K, self.ch_in * self.kernel_size, self.ch_out)
    # Per-block bias generated from z: (B * K) x ch_out.
    bias = self.bias_model(z).view(B, self.ch_out, K)
    bias = bias.permute(0, 2, 1).contiguous().view(B * K, self.ch_out)
    # Dynamic convolution as one batched matrix multiply, then restore
    # the B x ch_out x T layout.
    y = th.bmm(x, weight)
    y = y + bias[:, None, :]
    y = y.view(B, -1, self.ch_out).permute(0, 2, 1).contiguous()
    return y