def forward()

in MaskRCNN/pytorch/maskrcnn_benchmark/layers/deconv.py [0:0]


    def forward(self, x, output_size=None):
        """Whiten (deconvolve) the input statistics and apply a transposed conv.

        The whitening is folded into the kernel instead of being applied to
        ``x`` directly, i.e. ``S * X * D * W = (S * X) * (D * W)`` where
        ``cov_isqrt`` plays the role of ``D``.

        Steps:
          0. Optionally normalize ``x``: ``norm_type == 'l1norm'`` divides each
             sample by its mean absolute value (plus ``eps``);
             ``norm_type == 'layernorm'`` applies ``self.layernorm``.
          1. (training) Collect patch rows from ``x``: ``unfold`` when
             ``kernel_size[0] > 1``, otherwise per-pixel channel vectors,
             sub-sampled by ``sampling_stride``.
          2. (training) Compute the row mean and second moment, optionally
             averaged across processes when ``self.sync`` is set, form
             ``cov = E[X^T X] - mean mean^T + eps * I`` and ``cov_isqrt`` with
             the Newton-Schulz helpers (defined elsewhere in the file), and
             update ``running_mean`` / ``running_cov_isqrt`` with
             ``self.momentum``. In eval mode the running buffers are used.
          3. Fold the mean and ``cov_isqrt`` into the spatially flipped weight
             and into the bias, flip the kernel back, and call
             ``F.conv_transpose2d``.

        Args:
            x: input tensor, unpacked as ``(N, C, H, W)``.
            output_size: optional requested output size, forwarded to
                ``self._output_padding`` to derive the output padding.

        Returns:
            The transposed-convolution output, or ``x`` itself when it is empty.
        """
        if x.numel() == 0:
            # NOTE(review): the input is returned unchanged, so its shape is the
            # input shape, not the transposed-conv output shape. With
            # self.sync in training, this early return also skips the
            # all_gather below that other ranks may be waiting on -- confirm
            # this case cannot occur on a single rank.
            return x
        _, C, _, _ = x.shape  # unpacking also enforces a 4-D input
        B = self.block

        x = x.contiguous()
        if self.norm_type == 'l1norm':
            x_norm = x.abs().mean(dim=(1, 2, 3), keepdim=True)
            x = x / (x_norm + self.eps)
        elif self.norm_type == 'layernorm':
            x = self.layernorm(x)

        if self.training:
            self.counter += 1

            # 1. im2col: N x cols x pixels -> N*pixels x cols
            if self.kernel_size[0] > 1:
                # NOTE(review): patches are extracted with self.padding as-is;
                # no extra "full correlation" padding is applied here.
                X = torch.nn.functional.unfold(
                    x, self.kernel_size, self.dilation, self.padding,
                    self.sampling_stride).transpose(1, 2).contiguous()
            else:
                # channel-wise: one row per (sub-sampled) pixel
                X = x.permute(0, 2, 3, 1).contiguous().view(-1, C)[::self.sampling_stride ** 2, :]

            if self.groups == 1:
                # (C//B*N*pixels, k*k*B)
                X = X.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1, self.num_features)
            else:
                X = X.view(-1, X.shape[-1])

            # 2. calculate mean, cov, cov_isqrt
            # X_mean stays flat; in the grouped case it holds
            # groups * num_features entries (viewed per group further down).
            X_mean = X.mean(0)

            if self.groups == 1:
                M = X.shape[0]
                XX_mean = X.t() @ X / M
            else:
                # Split the flat features into per-group blocks:
                # (rows, groups*num_features) -> (groups, rows, num_features).
                # Without this, the batched transpose/matmul below cannot run
                # on the 2-D X.
                Xg = X.view(-1, self.groups, self.num_features).transpose(0, 1)
                M = Xg.shape[1]
                XX_mean = Xg.transpose(1, 2) @ Xg / M

            if self.sync:
                process_group = self.process_group
                world_size = 1
                if not self.process_group:
                    process_group = torch.distributed.group.WORLD
                if torch.distributed.is_initialized():
                    world_size = torch.distributed.get_world_size(process_group)
                    # sync once: pack mean and second moment into one buffer,
                    # gather it from every rank, and average.
                    sync_data = torch.cat([X_mean.view(-1), XX_mean.view(-1)], dim=0)
                    sync_data_list = [torch.empty_like(sync_data) for _ in range(world_size)]
                    sync_data_list = diffdist.functional.all_gather(sync_data_list, sync_data)
                    sync_data = torch.stack(sync_data_list).mean(0)
                    X_mean = sync_data[:X_mean.numel()].view(X_mean.shape)
                    XX_mean = sync_data[X_mean.numel():].view(XX_mean.shape)

            if self.groups == 1:
                Id = torch.eye(XX_mean.shape[1], dtype=X.dtype, device=X.device)
                cov = XX_mean - X_mean.unsqueeze(1) @ X_mean.unsqueeze(0) + self.eps * Id
                cov_isqrt = isqrt_newton_schulz_autograd(cov, self.n_iter)
            else:
                Id = torch.eye(self.num_features, dtype=X.dtype, device=X.device).expand(
                    self.groups, self.num_features, self.num_features)
                # per-group mean: (groups, num_features)
                Xm = X_mean.view(self.groups, self.num_features)
                cov = XX_mean - Xm.unsqueeze(2) @ Xm.unsqueeze(1) + self.eps * Id
                cov_isqrt = isqrt_newton_schulz_autograd_batch(cov, self.n_iter)

            self.running_mean.mul_(1 - self.momentum)
            self.running_mean.add_(X_mean.detach() * self.momentum)

            self.running_cov_isqrt.mul_(1 - self.momentum)
            self.running_cov_isqrt.add_(cov_isqrt.detach() * self.momentum)

        else:
            X_mean = self.running_mean
            cov_isqrt = self.running_cov_isqrt

        # 3. S * X * D * W = (S * X) * (D * W)
        # The difference between conv and corr is the flipped kernel: whiten
        # the flipped kernel, then flip the result back below. Both branches
        # must start from the flipped kernel so the final flip restores the
        # original orientation.
        weight = torch.flip(self.weight, [2, 3])

        if self.groups == 1:
            w = weight.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1, self.num_features) @ cov_isqrt
            # NOTE(review): this correction has weight.shape[0] entries; it only
            # matches the bias length when that equals the number of output
            # channels -- confirm for the layer configurations in use.
            b = self.bias - (w @ (X_mean.unsqueeze(1))).view(self.weight.shape[0], -1).sum(1)
            w = w.view(-1, C // B, self.num_features).transpose(1, 2).contiguous()
        else:
            w = weight.view(C // B, -1, self.num_features) @ cov_isqrt
            b = self.bias - (w @ (X_mean.view(-1, self.num_features, 1))).view(self.bias.shape)

        w = torch.flip(w.view(weight.shape), [2, 3])

        output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size)

        return F.conv_transpose2d(
            x, w, b, self.stride, self.padding,
            output_padding, self.groups, self.dilation)