in crlapi/sl/architectures/firefly_vgg/sp/conv.py [0:0]
def spffn_forward(self, x, alpha=-1):
    """Run the layer while injecting the split / grow perturbation noise.

    The wrapped ``self.module`` is applied to ``x`` as usual. The same input
    is also unfolded into patches via ``self.get_conv_patches``, so the
    perturbation weights (``self.v``, ``self.vno``, ``self.vni``) can be
    applied as plain matrix products with one row per spatial position.

    Channel layout, as implied by the slicing below:
      * input channels: the first ``C_in - self.ein`` are the existing
        inputs, the trailing ``self.ein`` are extra inputs;
      * output channels: the first ``C_out - self.eout`` are the existing
        outputs, the trailing ``self.eout`` are extra outputs.

    Perturbations:
      * ``self.can_split``: noise from ``self.v`` (existing inputs) is both
        added to and subtracted from the existing outputs. The two variants
        go through the optional ``self.bn`` and ``self._activate``
        separately, and the two activations are averaged.
      * ``self.eout > 0``: noise from ``self.vno`` (existing inputs) is
        added to the extra outputs.
      * ``self.ein > 0``: noise from ``self.vni`` (extra inputs) is added to
        every output channel.

    Args:
        x: Input batch. Presumably ``[B, C_in, H, W]``, with
            ``self.module(x)`` giving ``[B, C_out, H, W]`` so that it lines
            up with the noise maps (TODO confirm).
        alpha: When ``alpha >= 0.``, each ``self.v`` / ``self.vno`` noise
            tensor ``n`` is replaced by ``n.detach() * y + n * alpha``, where
            ``y`` is the matching channel slice of ``self.y``. Gradient to
            the noise weights then flows only through the ``n * alpha`` term.
            A negative value (the default ``-1``) uses the raw noise.

    Returns:
        The activated output. It is also stored on ``self.output``.
    """
    base = self.module(x)
    patches = self.get_conv_patches(x)
    B, H, W, C_in, kh, kw = patches.size()
    n_pix = B * H * W
    cin = C_in - self.ein              # number of existing input channels
    cout = base.shape[1] - self.eout   # number of existing output channels
    grow_in = self.ein > 0
    grow_out = self.eout > 0

    def to_map(rows):
        # [B*H*W, C] (one row per spatial position) -> [B, C, H, W]
        return rows.view(B, H, W, -1).permute(0, 3, 1, 2)

    def blend(noise, channels):
        # self.y is only indexed when alpha is in use, so it is never touched
        # on the default (alpha < 0) path.
        if alpha >= 0.:
            return noise.detach() * self.y[:, channels, :, :] + noise * alpha
        return noise

    # Rows = spatial positions; columns = (input channel, kh, kw) flattened.
    cols = patches.view(n_pix, -1, kh * kw)
    old_cols = (cols[:, :cin, :] if grow_in else cols).view(n_pix, -1)
    if grow_in:
        new_cols = cols[:, cin:, :].view(n_pix, -1)

    if self.can_split:
        split_noise = blend(
            to_map(old_cols.mm(self.v.view(-1, cin * kh * kw).t())),
            slice(None, cout),
        )
    if grow_out:
        out_noise = blend(
            to_map(old_cols.mm(self.vno.view(-1, cin * kh * kw).t())),
            slice(cout, None),
        )
    if grow_in:
        in_noise = to_map(new_cols.mm(self.vni.view(-1, self.ein * kh * kw).t()))
        if grow_out:
            in_noise_old, in_noise_new = in_noise[:, :cout], in_noise[:, cout:]
        else:
            in_noise_old = in_noise

    # The extra output channels are identical for every split variant.
    if grow_out:
        extra = base[:, cout:, :, :] + out_noise
        if grow_in:
            extra = extra + in_noise_new

    kept = base[:, :cout, :, :]
    if self.can_split:
        variants = [kept + split_noise, kept - split_noise]
    else:
        variants = [kept]
    if grow_in:
        variants = [t + in_noise_old for t in variants]
    if grow_out:
        variants = [torch.cat((t, extra), 1) for t in variants]
    # All bn calls happen before any activation call, in variant order, so
    # any state bn keeps is updated in the same sequence as before.
    if self.has_bn:
        variants = [self.bn(t) for t in variants]
    variants = [self._activate(t) for t in variants]

    output = (variants[0] + variants[1]) / 2. if self.can_split else variants[0]
    self.output = output
    return output