Classification/models/rfnorm.py [329:406]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
                X = torch.nn.functional.unfold(x, self.kernel_size,self.dilation,self.padding,self.sampling_stride).transpose(1, 2).contiguous()
                if self.norm_type=='rfnorm': 
                    X_std=X.std(dim=-1,keepdim=True)+self.rf_eps
                    X=(X-X.mean(dim=-1,keepdim=True))/X_std
                    
            else:
                #channel wise
                X = x.permute(0, 2, 3, 1).contiguous().view(-1, C)[::self.sampling_stride**2,:]

            if self.groups==1:
                # (C//B*N*pixels,k*k*B)
                X = X.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1, self.num_features)
            else:
                X=X.view(-1,X.shape[-1])

            # 2. calculate mean,cov,cov_isqrt

            if self.sync:
                process_group = self.process_group
                world_size = 1
                if not self.process_group:
                    process_group = torch.distributed.group.WORLD

                X_mean = X.mean(0)

                if self.groups==1:
                    M = X.shape[0]
                    XX_mean = X.t()@X/M 
                else:
                    M = X.shape[1]
                    XX_mean = X.transpose(1,2)@X/M 
                    
                if torch.distributed.is_initialized():
                    world_size = torch.distributed.get_world_size(process_group)
                    
                    #sync once implementation:
                    sync_data=torch.cat([X_mean.view(-1),XX_mean.view(-1)],dim=0)
                    sync_data_list=[torch.empty_like(sync_data) for k in range(world_size)]
                    sync_data_list = diffdist.functional.all_gather(sync_data_list, sync_data)
                    sync_data=torch.stack(sync_data_list).mean(0)
                    X_mean=sync_data[:X_mean.numel()].view(X_mean.shape)
                    XX_mean=sync_data[X_mean.numel():].view(XX_mean.shape)
                    

                if self.groups==1:
                    cov= XX_mean- X_mean.unsqueeze(1) @X_mean.unsqueeze(0)
                    Id = torch.eye(cov.shape[1], dtype=cov.dtype, device=cov.device)
                    cov_isqrt = isqrt_newton_schulz_autograd(cov+self.eps*Id, self.n_iter)
                else:
                    cov= (XX_mean- (X_mean.unsqueeze(2)) @X_mean.unsqueeze(1))
                    Id = torch.eye(self.num_features, dtype=cov.dtype, device=cov.device).expand(self.groups, self.num_features, self.num_features)
                    cov_isqrt = isqrt_newton_schulz_autograd_batch(cov+self.eps*Id, self.n_iter)

            else:
            
                X_mean = X.mean(0)
                X = X - X_mean.unsqueeze(0)

                if self.groups==1:
                    #cov = X.t() @ X / X.shape[0] + self.eps * torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
                    Id=torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
                    cov = torch.addmm(beta=self.eps, input=Id, alpha=1. / X.shape[0], mat1=X.t(), mat2=X)
                    cov_isqrt = isqrt_newton_schulz_autograd(cov, self.n_iter)
                else:
                    X = X.view(-1, self.groups, self.num_features).transpose(0, 1)
                    Id = torch.eye(self.num_features, dtype=X.dtype, device=X.device).expand(self.groups, self.num_features, self.num_features)
                    cov = torch.baddbmm(beta=self.eps, input=Id, alpha=1. / X.shape[1],  mat1=X.transpose(1, 2),  mat2=X)
                    cov_isqrt = isqrt_newton_schulz_autograd_batch(cov, self.n_iter)

            # track stats for evaluation.
            self.running_mean.mul_(1 - self.momentum)                
            self.running_mean.add_(X_mean.detach() * self.momentum)
            self.running_cov_isqrt.mul_(1 - self.momentum)
            self.running_cov_isqrt.add_(cov_isqrt.detach() * self.momentum)

        else:
            X_mean = self.running_mean
            cov_isqrt = self.running_cov_isqrt
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



Classification/models/rfnorm.py [522:603]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
                X = torch.nn.functional.unfold(x, self.kernel_size,self.dilation,self.padding,self.sampling_stride).transpose(1, 2).contiguous()
                if self.norm_type=='rfnorm': 
                    X_std=X.std(dim=-1,keepdim=True)+self.rf_eps
                    X=(X-X.mean(dim=-1,keepdim=True))/X_std
            else:
                #channel wise
                X = x.permute(0, 2, 3, 1).contiguous().view(-1, C)[::self.sampling_stride**2,:]

            if self.groups==1:
                # (C//B*N*pixels,k*k*B)
                X = X.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1, self.num_features)
            else:
                X=X.view(-1,X.shape[-1])


            if self.sync:
                process_group = self.process_group
                world_size = 1
                if not self.process_group:
                    process_group = torch.distributed.group.WORLD
            

                # 2. calculate mean,cov,cov_isqrt

                X_mean = X.mean(0)

                if self.groups==1:
                    M = X.shape[0]
                    XX_mean = X.t()@X/M 
                else:
                    M = X.shape[1]
                    XX_mean = X.transpose(1,2)@X/M 

                if torch.distributed.is_initialized():
                    world_size = torch.distributed.get_world_size(process_group)
                    
                    #sync once implementation:
                    sync_data=torch.cat([X_mean.view(-1),XX_mean.view(-1)],dim=0)
                    sync_data_list=[torch.empty_like(sync_data) for k in range(world_size)]
                    sync_data_list = diffdist.functional.all_gather(sync_data_list, sync_data)
                    sync_data=torch.stack(sync_data_list).mean(0)
                    X_mean=sync_data[:X_mean.numel()].view(X_mean.shape)
                    XX_mean=sync_data[X_mean.numel():].view(XX_mean.shape)
                    

                if self.groups==1:
                    cov= XX_mean- X_mean.unsqueeze(1) @X_mean.unsqueeze(0)
                    Id = torch.eye(cov.shape[1], dtype=cov.dtype, device=cov.device)
                    cov_isqrt = isqrt_newton_schulz_autograd(cov+self.eps*Id, self.n_iter)
                else:
                    cov= (XX_mean- (X_mean.unsqueeze(2)) @X_mean.unsqueeze(1))
                    Id = torch.eye(self.num_features, dtype=cov.dtype, device=cov.device).expand(self.groups, self.num_features, self.num_features)
                    cov_isqrt = isqrt_newton_schulz_autograd_batch(cov+self.eps*Id, self.n_iter)


            else:

                X_mean = X.mean(0)
                X = X - X_mean.unsqueeze(0)

                if self.groups==1:
                    #cov = X.t() @ X / X.shape[0] + self.eps * torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
                    Id=torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
                    cov = torch.addmm(beta=self.eps, input=Id, alpha=1. / X.shape[0], mat1=X.t(), mat2=X)
                    cov_isqrt = isqrt_newton_schulz_autograd(cov, self.n_iter)
                else:
                    X = X.view(-1, self.groups, self.num_features).transpose(0, 1)
                    Id = torch.eye(self.num_features, dtype=X.dtype, device=X.device).expand(self.groups, self.num_features, self.num_features)
                    cov = torch.baddbmm(beta=self.eps, input=Id, alpha=1. / X.shape[1], mat1=X.transpose(1, 2), mat2=X)
                    cov_isqrt = isqrt_newton_schulz_autograd_batch(cov, self.n_iter)


            self.running_mean.mul_(1 - self.momentum)
            self.running_mean.add_(X_mean.detach() * self.momentum)
            
            self.running_cov_isqrt.mul_(1 - self.momentum)
            self.running_cov_isqrt.add_(cov_isqrt.detach() * self.momentum)


        else:
            X_mean = self.running_mean
            cov_isqrt = self.running_cov_isqrt
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



