Classification/models/deconv.py [20:510]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
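# The module-level imports assumed by this excerpt (not shown above; inferred from
# how the names are used below):
#   import math
#   import torch
#   import torch.nn as nn
#   import torch.nn.functional as F
#   from torch.nn.modules import conv
#   from torch.nn.modules.utils import _pair
#   import diffdist.functional   # only needed when sync=True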
class Delinear(nn.Module):
    __constants__ = ['bias', 'in_features', 'out_features']

    def __init__(self, in_features, out_features, bias=True, eps=1e-5, n_iter=5, momentum=0.1, block=512,sync=False,norm_type='none'):
        super(Delinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

        if block > in_features:
            block = in_features
        else:
            if in_features%block!=0:
                block=math.gcd(block,in_features)
                print('block size set to:', block)
        self.block = block
        self.momentum = momentum
        self.n_iter = n_iter
        self.eps = eps
        self.register_buffer('running_mean', torch.zeros(self.block))
        self.register_buffer('running_cov_isqrt', torch.eye(self.block))
        self.norm_type=norm_type
        self.sync=sync
        self.process_group=None

        if self.norm_type=='layernorm':
            self.layernorm=LayerNorm(self.eps)

    def _specify_process_group(self, process_group):
        self.process_group = process_group

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        if input.numel()==0:
            return input

        input=input.contiguous()
        if self.norm_type=='l1norm':
            input_norm=input.abs().mean(dim=-1,keepdim=True)
            input =  input/ (input_norm + self.eps)
        if self.norm_type=='layernorm':
            input=self.layernorm(input)
            
        if self.training:
            # 1. reshape
            X=input.reshape(-1, self.block)

            # 2. calculate mean,cov,cov_isqrt

            N = X.shape[0]
            X_mean = X.mean(0)
            XX_mean = X.t()@X/N 

            if self.sync:
                process_group = self.process_group
                world_size = 1
                if not self.process_group:
                    process_group = torch.distributed.group.WORLD

                if torch.distributed.is_initialized():
                    world_size = torch.distributed.get_world_size(process_group)                    
                    #sync once implementation:
                    sync_data=torch.cat([X_mean.view(-1),XX_mean.view(-1)],dim=0)
                    sync_data_list=[torch.empty_like(sync_data) for k in range(world_size)]
                    sync_data_list = diffdist.functional.all_gather(sync_data_list, sync_data)
                    sync_data=torch.stack(sync_data_list).mean(0)
                    X_mean=sync_data[:X_mean.numel()].view(X_mean.shape)
                    XX_mean=sync_data[X_mean.numel():].view(XX_mean.shape)

            Id = torch.eye(XX_mean.shape[1], dtype=X.dtype, device=X.device)
            cov= XX_mean- X_mean.unsqueeze(1) @X_mean.unsqueeze(0)+self.eps*Id
            cov_isqrt = isqrt_newton_schulz_autograd(cov, self.n_iter)
                           
            # track stats for evaluation
            self.running_mean.mul_(1 - self.momentum)
            self.running_mean.add_(X_mean.detach() * self.momentum)
            self.running_cov_isqrt.mul_(1 - self.momentum)
            self.running_cov_isqrt.add_(cov_isqrt.detach() * self.momentum)
                
        else:

            X_mean = self.running_mean
            cov_isqrt = self.running_cov_isqrt

        #3. S * X * D * W = (S * X) * (D * W)
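        # Equivalently (for block == in_features): instead of explicitly whitening
        # the input and then applying the layer,
        #     y = ((x - mean) @ D) @ W.T + bias,        D = cov_isqrt
        # the decorrelation D is folded into the parameters of a plain linear layer,
        #     y = x @ (W @ D).T + (bias - (W @ D) @ mean),
        # which is what the code below computes (applied block-wise to length-block
        # slices of each weight row when block < in_features).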

        w = self.weight.view(-1, self.block) @ cov_isqrt
        if self.bias is None:
            b = - (w @ (X_mean.unsqueeze(1))).view(self.weight.shape[0], -1).sum(1)
        else:
            b = self.bias - (w @ (X_mean.unsqueeze(1))).view(self.weight.shape[0], -1).sum(1)

        w = w.view(self.weight.shape)
        return F.linear(input, w, b)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )


class Deconv(conv._ConvNd):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,groups=1,bias=True, eps=1e-5, n_iter=5, momentum=0.1, block=64, sampling_stride=3,freeze=False,freeze_iter=100,sync=False,norm_type='none'):
  
        self.momentum = momentum
        self.n_iter = n_iter
        self.eps = eps
        self.counter=0
        super(Deconv, self).__init__(
            in_channels, out_channels,  _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation),
            False, _pair(0), groups, bias, padding_mode='zeros')


        if block > in_channels:
            block = in_channels
        else:
            if in_channels%block!=0:
                block=math.gcd(block,in_channels)

        if groups>1:
            #grouped conv
            block=in_channels//groups

        self.block=block

        self.num_features = kernel_size**2 *block


        if groups==1:
            self.register_buffer('running_mean', torch.zeros(self.num_features))
            self.register_buffer('running_cov_isqrt', torch.eye(self.num_features))
        else:
            self.register_buffer('running_mean', torch.zeros(kernel_size ** 2 * in_channels))
            self.register_buffer('running_cov_isqrt', torch.eye(self.num_features).repeat(in_channels // block, 1, 1))

        self.sampling_stride=sampling_stride*stride
        self.counter=0
        self.freeze_iter=freeze_iter
        self.freeze=freeze
        self.norm_type=norm_type        
        self.sync=sync
        self.process_group=None
        
        if self.norm_type=='layernorm':
            self.layernorm=LayerNorm(self.eps)


            
    def _specify_process_group(self, process_group):
        self.process_group = process_group
    
    def forward(self, x):

        if x.numel()==0:
            return x
        N, C, H, W = x.shape
        B = self.block

        x=x.contiguous()
        if self.norm_type=='l1norm':
            x_norm=x.abs().mean(dim=(1,2,3),keepdim=True)            
            x =  x/ (x_norm + self.eps)

        elif self.norm_type=='layernorm':
            x=self.layernorm(x)


        frozen = False
        if self.training:
            self.counter += 1
            # with freeze=True, the statistics are refreshed only while
            # counter % (freeze_iter * 10) <= freeze_iter and are frozen otherwise
            frozen = self.freeze and (self.counter % (self.freeze_iter * 10) > self.freeze_iter)

        if self.training and (not frozen):

            # 1. im2col: N x cols x pixels -> N*pixels x cols
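            # patches are gathered with self.sampling_stride (a multiple of the conv
            # stride) only to estimate the statistics below; the actual convolution
            # at the end of forward() still uses self.stride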
            if self.kernel_size[0]>1:
                X = torch.nn.functional.unfold(x, self.kernel_size,self.dilation,self.padding,self.sampling_stride).transpose(1, 2).contiguous()                    
            else:
                #channel wise
                X = x.permute(0, 2, 3, 1).contiguous().view(-1, C)[::self.sampling_stride**2,:]

            if self.groups==1:
                # (C//B*N*pixels,k*k*B)
                X = X.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1, self.num_features)
            else:
                X=X.view(-1,X.shape[-1])

            # 2. calculate mean,cov,cov_isqrt
            X_mean = X.mean(0)

            if self.groups==1:
                M = X.shape[0]
                XX_mean = X.t()@X/M 
            else:
                # per-group covariance: reshape to (groups, M, num_features) first
                X = X.view(-1, self.groups, self.num_features).transpose(0, 1)
                M = X.shape[1]
                XX_mean = X.transpose(1, 2) @ X / M
                
            if self.sync:
                process_group = self.process_group
                world_size = 1
                if not self.process_group:
                    process_group = torch.distributed.group.WORLD
        
                if torch.distributed.is_initialized():
                    world_size = torch.distributed.get_world_size(process_group)
                    #sync once implementation:
                    sync_data=torch.cat([X_mean.view(-1),XX_mean.view(-1)],dim=0)
                    sync_data_list=[torch.empty_like(sync_data) for k in range(world_size)]
                    sync_data_list = diffdist.functional.all_gather(sync_data_list, sync_data)
                    sync_data=torch.stack(sync_data_list).mean(0)
                    X_mean=sync_data[:X_mean.numel()].view(X_mean.shape)
                    XX_mean=sync_data[X_mean.numel():].view(XX_mean.shape)
                

            if self.groups==1:
                Id = torch.eye(XX_mean.shape[1], dtype=X.dtype, device=X.device)
                cov= XX_mean- X_mean.unsqueeze(1) @X_mean.unsqueeze(0)+self.eps*Id
                cov_isqrt = isqrt_newton_schulz_autograd(cov, self.n_iter)
            else:
                Id = torch.eye(self.num_features, dtype=X.dtype, device=X.device).expand(self.groups, self.num_features, self.num_features)
                cov = XX_mean - X_mean.view(self.groups, self.num_features, 1) @ X_mean.view(self.groups, 1, self.num_features) + self.eps * Id
                cov_isqrt = isqrt_newton_schulz_autograd_batch(cov, self.n_iter)

            # track stats for evaluation.
            self.running_mean.mul_(1 - self.momentum)                
            self.running_mean.add_(X_mean.detach() * self.momentum)
            self.running_cov_isqrt.mul_(1 - self.momentum)
            self.running_cov_isqrt.add_(cov_isqrt.detach() * self.momentum)

        else:
            X_mean = self.running_mean
            cov_isqrt = self.running_cov_isqrt

        #3. S * X * D * W = (S * X) * (D * W)
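        # same folding as in Delinear, applied to the im2col view of the kernel:
        # each k*k*block slice of the weight is multiplied by cov_isqrt and a
        # compensating term goes into the bias, so the single F.conv2d call below
        # effectively convolves with decorrelated patches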
        if self.groups == 1:
            w = self.weight.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1, self.num_features) @ cov_isqrt
            b = - (w @ X_mean.unsqueeze(1)).view(self.weight.shape[0], -1).sum(1)
            if self.bias is not None:
                b = b + self.bias
            w = w.view(-1, C // B, self.num_features).transpose(1, 2).contiguous()
        else:
            w = self.weight.view(C // B, -1, self.num_features) @ cov_isqrt
            b = - (w @ X_mean.view(-1, self.num_features, 1)).view(-1)
            if self.bias is not None:
                b = b + self.bias

        w = w.view(self.weight.shape)
        
        x= F.conv2d(x, w, b, self.stride, self.padding, self.dilation, self.groups)

        return x


#class DeconvTransposed(conv._ConvTransposeNd): #latest pytorch
class DeconvTransposed(conv._ConvTransposeMixin):#backward compatibility

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1,bias=True, dilation=1, eps=1e-5, n_iter=5, momentum=0.1, block=64, sampling_stride=3,sync=False,norm_type='none'):

        self.momentum = momentum
        self.n_iter = n_iter
        self.eps = eps
        self.counter=0
        super(DeconvTransposed, self).__init__(
            in_channels, out_channels, _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation),
            True, _pair(output_padding), groups, bias, padding_mode='zeros')

        if block > in_channels:
            block = in_channels
        else:
            if in_channels%block!=0:
                block=math.gcd(block,in_channels)

        if groups>1:
            #grouped conv
            block=in_channels//groups

        self.block=block

        self.num_features = kernel_size**2 *block
        if groups==1:
            self.register_buffer('running_mean', torch.zeros(self.num_features))
            self.register_buffer('running_cov_isqrt', torch.eye(self.num_features))
        else:
            self.register_buffer('running_mean', torch.zeros(kernel_size ** 2 * in_channels))
            self.register_buffer('running_cov_isqrt', torch.eye(self.num_features).repeat(in_channels // block, 1, 1))

        self.sampling_stride=sampling_stride*stride
        self.counter=0
        self.norm_type = norm_type
        self.sync=sync
        self.process_group=None

        if self.norm_type=='layernorm':
            self.layernorm=LayerNorm(self.eps)

    def _specify_process_group(self, process_group):
        self.process_group = process_group

    def forward(self, x,output_size=None):
        if x.numel()==0:
            return x
        N, C, H, W = x.shape
        B = self.block
        
        x=x.contiguous()
        if self.norm_type=='l1norm':
            x_norm=x.abs().mean(dim=(1,2,3),keepdim=True)
            x =  x/ (x_norm + self.eps)

        elif self.norm_type=='layernorm':
            x=self.layernorm(x)

        if self.training:
            self.counter+=1

            # 1. im2col: N x cols x pixels -> N*pixels x cols
            if self.kernel_size[0] > 1:
                # the adjoint of a conv operation is a full correlation operation, so pad first
                padding = (self.padding[0] + self.kernel_size[0] - 1, self.padding[1] + self.kernel_size[1] - 1)
                X = torch.nn.functional.unfold(x, self.kernel_size, self.dilation, padding, self.sampling_stride).transpose(1, 2).contiguous()
            else:
                #channel wise
                X = x.permute(0, 2, 3, 1).contiguous().view(-1, C)[::self.sampling_stride**2,:]

            if self.groups==1:
                # (C//B*N*pixels,k*k*B)
                X = X.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1, self.num_features)
            else:
                X=X.view(-1,X.shape[-1])

            # 2. calculate mean,cov,cov_isqrt

            X_mean = X.mean(0)

            if self.groups==1:
                M = X.shape[0]
                XX_mean = X.t()@X/M 
            else:
                # per-group covariance: reshape to (groups, M, num_features) first
                X = X.view(-1, self.groups, self.num_features).transpose(0, 1)
                M = X.shape[1]
                XX_mean = X.transpose(1, 2) @ X / M

            if self.sync:
                process_group = self.process_group
                world_size = 1
                if not self.process_group:
                    process_group = torch.distributed.group.WORLD
                if torch.distributed.is_initialized():
                    world_size = torch.distributed.get_world_size(process_group)
                    #sync once implementation:
                    sync_data=torch.cat([X_mean.view(-1),XX_mean.view(-1)],dim=0)
                    sync_data_list=[torch.empty_like(sync_data) for k in range(world_size)]
                    sync_data_list = diffdist.functional.all_gather(sync_data_list, sync_data)
                    sync_data=torch.stack(sync_data_list).mean(0)
                    X_mean=sync_data[:X_mean.numel()].view(X_mean.shape)
                    XX_mean=sync_data[X_mean.numel():].view(XX_mean.shape)
                    

            if self.groups==1:
                Id = torch.eye(XX_mean.shape[1], dtype=X.dtype, device=X.device)
                cov= XX_mean- X_mean.unsqueeze(1) @X_mean.unsqueeze(0)+self.eps*Id
                cov_isqrt = isqrt_newton_schulz_autograd(cov, self.n_iter)
            else:
                Id = torch.eye(self.num_features, dtype=X.dtype, device=X.device).expand(self.groups, self.num_features, self.num_features)
                cov = XX_mean - X_mean.view(self.groups, self.num_features, 1) @ X_mean.view(self.groups, 1, self.num_features) + self.eps * Id
                cov_isqrt = isqrt_newton_schulz_autograd_batch(cov, self.n_iter)


            self.running_mean.mul_(1 - self.momentum)
            self.running_mean.add_(X_mean.detach() * self.momentum)
            
            self.running_cov_isqrt.mul_(1 - self.momentum)
            self.running_cov_isqrt.add_(cov_isqrt.detach() * self.momentum)


        else:
            X_mean = self.running_mean
            cov_isqrt = self.running_cov_isqrt

        #3. S * X * D * W = (S * X) * (D * W)

        # the difference between conv and corr is the flipped kernel
        weight=torch.flip(self.weight,[2,3])

        if self.groups == 1:
            w = weight.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1, self.num_features) @ cov_isqrt
            b = - (w @ X_mean.unsqueeze(1)).view(self.weight.shape[0], -1).sum(1)
            if self.bias is not None:
                b = b + self.bias
            w = w.view(-1, C // B, self.num_features).transpose(1, 2).contiguous()
        else:
            w = weight.view(C // B, -1, self.num_features) @ cov_isqrt
            b = - (w @ X_mean.view(-1, self.num_features, 1)).view(-1)
            if self.bias is not None:
                b = b + self.bias

        # flip back to the original kernel orientation
        w = torch.flip(w.view(self.weight.shape), [2, 3])

        output_padding = self._output_padding(x, output_size, self.stride, self.padding, self.kernel_size)
        
        x = F.conv_transpose2d(
            x, w, b, self.stride, self.padding,
            output_padding, self.groups, self.dilation)

        return x




def isqrt_newton_schulz_autograd(A, numIters,norm='norm',method='inverse_newton'):
    dim = A.shape[0]

    if norm=='norm':
        normA=A.norm()
    else:
        normA=A.trace()

    I = torch.eye(dim, dtype=A.dtype, device=A.device)
    Y = A.div(normA)

    Z = torch.eye(dim, dtype=A.dtype, device=A.device)

    if method=='denman_beavers':
        for i in range(numIters):
            #T = 0.5*(3.0*I - Z@Y)
            T=torch.addmm(beta=1.5, input=I, alpha=-0.5, mat1=Z, mat2=Y)
            Y = Y.mm(T)
            Z = T.mm(Z)
    elif method=='newton':
        for i in range(numIters):
            #Z =  1.5 * Z - 0.5* Z@ Z @ Z @ Y
            Z = torch.addmm(beta=1.5, input=Z, alpha=-0.5, mat1=torch.matrix_power(Z, 3), mat2=Y)
    elif method=='inverse_newton':
        for i in range(numIters):
            T= (3*I - Y)/2
            Y = torch.mm(torch.matrix_power(T, 2),Y)
            Z = Z.mm(T)

    #A_sqrt = Y* torch.sqrt(normA)
    A_isqrt =Z/ torch.sqrt(normA)
    return A_isqrt
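
# Convergence note for the default 'inverse_newton' method: with Y_0 = A/||A|| and
# Z_0 = I, Y is driven towards the identity and Z converges to (A/||A||)^(-1/2),
# so the returned A_isqrt approximates A^(-1/2), i.e. A_isqrt @ A @ A_isqrt ~= I
# for a symmetric positive definite A. The batched variant below applies the same
# iteration to each matrix in the batch.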


def isqrt_newton_schulz_autograd_batch(A, numIters,method='inverse_newton'):
    batchSize,dim,_ = A.shape
    normA=A.view(batchSize, -1).norm(2, 1).view(batchSize, 1, 1)
    Y = A.div(normA)
    I = torch.eye(dim,dtype=A.dtype,device=A.device).unsqueeze(0).expand_as(A)
    Z = torch.eye(dim,dtype=A.dtype,device=A.device).unsqueeze(0).expand_as(A)
    if method=='denman_beavers':
        for i in range(numIters):
            T = 0.5*(3.0*I - Z.bmm(Y))
            Y = Y.bmm(T)
            Z = T.bmm(Z)
    elif method=='inverse_newton':
        for i in range(numIters):
            T= (3*I - Y)/2
            Y = torch.bmm(torch.matrix_power(T, 2),Y)
            Z = Z.bmm(T)
    #A_sqrt = Y*torch.sqrt(normA)
    A_isqrt = Z / torch.sqrt(normA)

    return A_isqrt



class LayerNorm(nn.Module):
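    # Parameter-free layer normalization: each sample is normalized over all of its
    # non-batch dimensions, with eps added to the standard deviation for stability.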
    def __init__(self, eps=1e-4):
        super(LayerNorm, self).__init__()
        self.eps=eps

    def forward(self, x):
        x_shape = x.shape
        x = x.reshape(x_shape[0], -1)
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True) + self.eps
        # x = (x - mean) / std  # disaster
        x = x / std - mean / std  # this is way more efficient
        x = x.view(x_shape)
        return x




if __name__=='__main__':

    pass
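    # Illustrative smoke test (a sketch added for this excerpt; the shapes and
    # hyper-parameters below are arbitrary examples, not taken from the repo).
    torch.manual_seed(0)

    # Delinear is a drop-in replacement for nn.Linear.
    delinear = Delinear(in_features=512, out_features=10, block=64)
    print('Delinear out:', tuple(delinear(torch.randn(8, 512)).shape))      # (8, 10)

    # Deconv is a drop-in replacement for nn.Conv2d.
    deconv = Deconv(16, 32, kernel_size=3, padding=1, block=16)
    print('Deconv out:', tuple(deconv(torch.randn(2, 16, 14, 14)).shape))   # (2, 32, 14, 14)

    # The Newton-Schulz routine should return an approximate inverse square root:
    # Z @ A @ Z should be close to the identity for a well-conditioned SPD matrix A.
    A = torch.randn(32, 32)
    A = A @ A.t() / 32 + torch.eye(32)
    Z = isqrt_newton_schulz_autograd(A, numIters=10)
    print('||Z A Z - I|| =', (Z @ A @ Z - torch.eye(32)).norm().item())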
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



MaskRCNN/pytorch/maskrcnn_benchmark/layers/deconv.py [20:510]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
(duplicate of Classification/models/deconv.py [20:510] above)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



