def spffn_forward()

in crlapi/sl/architectures/firefly_vgg/sp/conv.py [0:0]


    def spffn_forward(self, x, alpha=-1):
        """Forward pass with linear noise terms used to probe split/grow directions.

        ``self.module(x)`` gives the base response, whose channel dimension is
        treated as ``cout`` pre-existing outputs followed by ``self.eout`` extra
        outputs.  ``self.get_conv_patches(x)`` gives a 6-D patch tensor,
        unpacked as ``(B, H, W, C_in, kh, kw)``.  Its first ``cin = C_in -
        self.ein`` input channels are the pre-existing inputs; the trailing
        ``self.ein`` channels are the extra inputs.

        Noise terms (each a matmul of flattened patches with a flattened
        parameter, reshaped to ``[B, C, H, W]``):

        * ``self.v`` (only when ``self.can_split``): pre-existing inputs ->
          pre-existing outputs; applied with both ``+`` and ``-`` signs.
        * ``self.vno`` (only when ``self.eout > 0``): pre-existing inputs ->
          extra outputs.
        * ``self.vni`` (only when ``self.ein > 0``): extra inputs -> all
          outputs; added with ``+`` in every branch.

        When ``alpha >= 0`` the ``v`` and ``vno`` terms are replaced by
        ``noise.detach() * self.y[:, <their channels>] + noise * alpha``.

        With ``self.can_split`` the result is the mean of the activations of
        the ``+v`` and ``-v`` branches; otherwise a single branch without ``v``
        is activated.  ``self.bn`` (if ``self.has_bn``) is applied before
        ``self._activate``.

        Args:
            x: input forwarded to ``self.module`` and ``self.get_conv_patches``.
            alpha: blending coefficient for the gated noise; the default ``-1``
                (any value failing ``alpha >= 0``) disables the blending.

        Returns:
            The activated tensor; it is also stored on ``self.output``.
        """
        response = self.module(x)  # channel dim = cout existing + eout extra outputs

        patch_grid = self.get_conv_patches(x)
        n_batch, n_rows, n_cols, c_in_all, k_h, k_w = patch_grid.size()
        kernel_area = k_h * k_w
        n_pix = n_batch * n_rows * n_cols
        cin = c_in_all - self.ein
        cout = response.shape[1] - self.eout

        def to_nchw(flat):
            # [n_pix, C] -> [B, H, W, C] -> [B, C, H, W]
            return flat.view(n_batch, n_rows, n_cols, -1).permute(0, 3, 1, 2)

        def gated(noise, channels):
            # Blend is only active for alpha >= 0; self.y is read only in that case.
            if alpha >= 0.:
                return noise.detach() * self.y[:, channels, :, :] + noise * alpha
            return noise

        unfolded = patch_grid.view(n_pix, -1, kernel_area)
        if self.ein > 0:
            old_in = unfolded[:, :cin, :].view(n_pix, -1)
            new_in = unfolded[:, cin:, :].view(n_pix, -1)
        else:
            old_in = unfolded.view(n_pix, -1)

        if self.can_split:
            split_proj = old_in.mm(self.v.view(-1, cin * kernel_area).t())
            split_noise = gated(to_nchw(split_proj), slice(None, cout))  # [B, cout, H, W]

        if self.eout > 0:
            grow_proj = old_in.mm(self.vno.view(-1, cin * kernel_area).t())
            grow_out_noise = gated(to_nchw(grow_proj), slice(cout, None))

        if self.ein > 0:
            new_in_proj = new_in.mm(self.vni.view(-1, self.ein * kernel_area).t())
            if self.eout > 0:
                # Columns split into contributions to existing vs. extra outputs.
                grow_in_old = to_nchw(new_in_proj[:, :cout])
                grow_in_new = to_nchw(new_in_proj[:, cout:])
            else:
                grow_in_old = to_nchw(new_in_proj)

        if self.eout > 0:
            # Extra output channels are identical for the +v and -v branches.
            extra = response[:, cout:, :, :] + grow_out_noise
            if self.ein > 0:
                extra = extra + grow_in_new

        def widen(pre):
            # Add the extra-input contribution, then append the extra outputs.
            if self.ein > 0:
                pre = pre + grow_in_old
            if self.eout > 0:
                pre = torch.cat((pre, extra), 1)
            return pre

        head = response[:, :cout, :, :]
        if not self.can_split:
            single = widen(head)
            if self.has_bn:
                single = self.bn(single)
            output = self._activate(single)
        else:
            plus = widen(head + split_noise)
            minus = widen(head - split_noise)
            if self.has_bn:
                plus = self.bn(plus)
                minus = self.bn(minus)
            output = (self._activate(plus) + self._activate(minus)) / 2.
        self.output = output
        return output