in models/layers.py [0:0]
def forward(self, input, z):
    """Apply a per-sample modulated convolution to ``input``.

    The kernel is scaled per sample and per input channel by
    ``self.modulation(z)``, optionally demodulated, and then applied to the
    whole batch with one grouped convolution (one group per sample).

    Args:
        input: 4-D tensor unpacked as ``(batch, in_channel, height, width)``.
            It does not need to be contiguous.
        z: Conditioning input handed to ``self.modulation``; its output must
            hold ``batch * in_channel`` elements.

    Returns:
        Tensor of shape ``(batch, self.out_channel, H', W')``. The spatial
        size depends on the branch taken: stride-2 transposed convolution
        when ``self.upsample`` is set, stride-2 convolution when
        ``self.downsample`` is set, otherwise a convolution padded with
        ``self.padding``.
    """
    batch, in_channel, height, width = input.shape
    kernel = self.kernel_size

    # One multiplicative factor per (sample, input channel).
    gamma = self.modulation(z).reshape(batch, 1, in_channel, 1, 1)
    # Assumes self.weight broadcasts against gamma to
    # (batch, out_channel, in_channel, k, k) -- TODO confirm against __init__.
    weight = self.scale * self.weight * gamma

    if self.demodulate:
        # Normalise every (sample, output channel) kernel to unit L2 norm;
        # the epsilon keeps rsqrt finite for an all-zero kernel.
        demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
        weight = weight * demod.reshape(batch, self.out_channel, 1, 1, 1)

    # Make the 5-D layout explicit (also validates the element count).
    weight = weight.reshape(batch, self.out_channel, in_channel, kernel, kernel)

    if self.upsample:
        # reshape (not view): folding batch into channels is impossible
        # without a copy for e.g. channels_last or permuted inputs.
        input = input.reshape(1, batch * in_channel, height, width)
        # conv_transpose2d expects (in_channels, out_channels / groups, k, k).
        weight = weight.transpose(1, 2).reshape(
            batch * in_channel, self.out_channel, kernel, kernel
        )
        out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
        out = out.reshape(batch, self.out_channel, out.shape[2], out.shape[3])
        out = self.blur(out)
    else:
        weight = weight.reshape(
            batch * self.out_channel, in_channel, kernel, kernel
        )
        if self.downsample:
            # Blur first; it may change the spatial size, so re-read it.
            input = self.blur(input)
            height, width = input.shape[2], input.shape[3]
            input = input.reshape(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
        else:
            input = input.reshape(1, batch * in_channel, height, width)
            out = F.conv2d(input, weight, padding=self.padding, groups=batch)
        out = out.reshape(batch, self.out_channel, out.shape[2], out.shape[3])

    # Both hooks are optional. An attribute that exists but is None is
    # treated the same as a missing one instead of crashing.
    activate = getattr(self, 'activate', None)
    if activate is not None:
        out = activate(out)
    bias = getattr(self, 'bias', None)
    if bias is not None:
        out = out + bias
    return out