Classification/models/SyncND.py [172:271]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        # 1. im2col: gather the k*k patches seen by each output pixel
        if self.kernel_size[0] > 1:
            X = torch.nn.functional.unfold(x, self.kernel_size, self.dilation, self.padding, self.sampling_stride).transpose(1, 2).contiguous()
        else:
            # 1x1 kernel: channel-wise, subsample pixels by the sampling stride
            X = x.permute(0, 2, 3, 1).contiguous().view(-1, C)[::self.sampling_stride ** 2, :]

        if self.groups == 1:
            # (C//B*N*pixels, k*k*B): fold the channel blocks into the sample dimension
            X = X.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1, self.num_features)
        else:
            # (groups, N*pixels, num_features): one slab of samples per group
            X = X.view(-1, self.groups, self.num_features).transpose(0, 1)

        # 2. accumulate per-device sums and cross-products; the mean and covariance
        #    are formed on the master after the cross-device reduction
        if self.groups == 1:
            X_sum = X.sum(0)
            XY_sum = X.t() @ X
            sum_size = X.shape[0]
        else:
            X_sum = X.sum(1)
            XY_sum = X.transpose(1, 2) @ X
            sum_size = X.shape[1]
        
        # 3. reduce-and-broadcast the statistics across devices
        if self._parallel_id == 0:
            X_mean, cov_isqrt = self._sync_master.run_master(_ChildMessage(X_sum, XY_sum, sum_size))
        else:
            X_mean, cov_isqrt = self._slave_pipe.run_slave(_ChildMessage(X_sum, XY_sum, sum_size))

        
        # 4. fold the whitening into the conv weights: X * deconv * conv = X * (deconv * conv)
        if self.groups == 1:
            w = self.weight.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1, self.num_features) @ cov_isqrt
            b = self.bias - (w @ (X_mean.unsqueeze(1))).view(self.weight.shape[0], -1).sum(1)
            w = w.view(-1, C // B, self.num_features).transpose(1, 2).contiguous()
        else:
            w = self.weight.view(C // B, -1, self.num_features) @ cov_isqrt
            b = self.bias - (w @ (X_mean.view(-1, self.num_features, 1))).view(self.bias.shape)

        w = w.view(self.weight.shape)
        x = F.conv2d(x, w, b, self.stride, self.padding, self.dilation, self.groups)

        return x



    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using the same device order makes the ReduceAdd operation faster.
        # Thanks to: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].X_sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].X_sum.get_device() for i in intermediates]

        total_size = sum([i[1].sum_size for i in intermediates])
        X_sum, XY_sum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        X_mean, cov_isqrt = self._compute_mean_isqrt(X_sum, XY_sum, total_size)

        broadcasted = Broadcast.apply(target_gpus, X_mean, cov_isqrt)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))

        return outputs

    def _compute_mean_isqrt(self, X_sum, XY_sum, size):
        """Compute the mean and the inverse square root of the covariance from the
        reduced sums, and maintain the moving averages on the master device."""
        assert size > 1, 'Estimating the covariance requires more than one sample.'
        X_mean = X_sum / size

        if self.groups == 1:
            # Cov = E[X X^T] - E[X] E[X]^T
            Cov = (XY_sum - X_sum.unsqueeze(1) @ X_mean.unsqueeze(0)) / size
            Id = torch.eye(Cov.shape[1], dtype=Cov.dtype, device=Cov.device)
            cov_isqrt = isqrt_newton_schulz_autograd(Cov + self.eps * Id, self.n_iter)
        else:
            # per-group covariance: Cov[g] = E[X_g X_g^T] - E[X_g] E[X_g]^T
            Cov = (XY_sum - X_sum.unsqueeze(2) @ X_mean.unsqueeze(1)) / size
            Id = torch.eye(self.num_features, dtype=Cov.dtype, device=Cov.device).expand(self.groups, self.num_features, self.num_features)
            cov_isqrt = isqrt_newton_schulz_autograd_batch(Cov + self.eps * Id, self.n_iter)

        self.running_mean.mul_(1 - self.momentum)
        self.running_mean.add_(X_mean.detach() * self.momentum)
        self.running_cov_isqrt.mul_(1 - self.momentum)
        self.running_cov_isqrt.add_(cov_isqrt.detach() * self.momentum)

        return X_mean, cov_isqrt
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
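
For reference, _compute_mean_isqrt above relies on isqrt_newton_schulz_autograd and its batched variant, which are defined elsewhere in the repository and not part of this excerpt. The snippet below is a minimal sketch of the standard coupled Newton-Schulz iteration such a helper typically implements to approximate Cov^(-1/2) with a fixed number of matrix multiplications, keeping the whitening differentiable; the function name isqrt_newton_schulz_sketch and the Frobenius-norm pre-scaling are illustrative assumptions, not the repository's code.

import torch

def isqrt_newton_schulz_sketch(A, n_iter):
    # Approximate A^(-1/2) for a symmetric positive-definite A (e.g. Cov + eps * Id)
    # using the coupled Newton-Schulz iteration.
    dim = A.shape[0]
    norm = A.norm()                      # Frobenius norm, used to pre-scale A
    Y = A / norm
    I = torch.eye(dim, dtype=A.dtype, device=A.device)
    Z = torch.eye(dim, dtype=A.dtype, device=A.device)
    for _ in range(n_iter):
        T = 0.5 * (3.0 * I - Z @ Y)
        Y = Y @ T                        # converges to (A / norm)^(1/2)
        Z = T @ Z                        # converges to (A / norm)^(-1/2)
    return Z / norm.sqrt()               # undo the pre-scaling

A batched helper (as the name isqrt_newton_schulz_autograd_batch suggests) would run the same iteration on a (groups, num_features, num_features) stack of covariances using batched matrix multiplies.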



Segmentation/models/segmentation/SyncND.py [76:175]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            # 1. im2col: gather the k*k patches seen by each output pixel
            if self.kernel_size[0] > 1:
                X = torch.nn.functional.unfold(x, self.kernel_size, self.dilation, self.padding, self.sampling_stride).transpose(1, 2).contiguous()
            else:
                # 1x1 kernel: channel-wise, subsample pixels by the sampling stride
                X = x.permute(0, 2, 3, 1).contiguous().view(-1, C)[::self.sampling_stride ** 2, :]

            if self.groups == 1:
                # (C//B*N*pixels, k*k*B): fold the channel blocks into the sample dimension
                X = X.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1, self.num_features)
            else:
                # (groups, N*pixels, num_features): one slab of samples per group
                X = X.view(-1, self.groups, self.num_features).transpose(0, 1)

        # 2. accumulate per-device sums and cross-products; the mean and covariance
        #    are formed on the master after the cross-device reduction
        if self.groups == 1:
            X_sum = X.sum(0)
            XY_sum = X.t() @ X
            sum_size = X.shape[0]
        else:
            X_sum = X.sum(1)
            XY_sum = X.transpose(1, 2) @ X
            sum_size = X.shape[1]
        
        # 3. reduce-and-broadcast the statistics across devices
        if self._parallel_id == 0:
            X_mean, cov_isqrt = self._sync_master.run_master(_ChildMessage(X_sum, XY_sum, sum_size))
        else:
            X_mean, cov_isqrt = self._slave_pipe.run_slave(_ChildMessage(X_sum, XY_sum, sum_size))

        
        # 4. fold the whitening into the conv weights: X * deconv * conv = X * (deconv * conv)
        if self.groups == 1:
            w = self.weight.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1, self.num_features) @ cov_isqrt
            b = self.bias - (w @ (X_mean.unsqueeze(1))).view(self.weight.shape[0], -1).sum(1)
            w = w.view(-1, C // B, self.num_features).transpose(1, 2).contiguous()
        else:
            w = self.weight.view(C // B, -1, self.num_features) @ cov_isqrt
            b = self.bias - (w @ (X_mean.view(-1, self.num_features, 1))).view(self.bias.shape)

        w = w.view(self.weight.shape)
        x = F.conv2d(x, w, b, self.stride, self.padding, self.dilation, self.groups)

        return x



    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using the same device order makes the ReduceAdd operation faster.
        # Thanks to: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].X_sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].X_sum.get_device() for i in intermediates]

        total_size = sum([i[1].sum_size for i in intermediates])
        X_sum, XY_sum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        X_mean, cov_isqrt = self._compute_mean_isqrt(X_sum, XY_sum, total_size)

        broadcasted = Broadcast.apply(target_gpus, X_mean, cov_isqrt)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))

        return outputs

    def _compute_mean_isqrt(self, X_sum, XY_sum, size):
        """Compute the mean and the inverse square root of the covariance from the
        reduced sums, and maintain the moving averages on the master device."""
        assert size > 1, 'Estimating the covariance requires more than one sample.'
        X_mean = X_sum / size

        if self.groups == 1:
            # Cov = E[X X^T] - E[X] E[X]^T
            Cov = (XY_sum - X_sum.unsqueeze(1) @ X_mean.unsqueeze(0)) / size
            Id = torch.eye(Cov.shape[1], dtype=Cov.dtype, device=Cov.device)
            cov_isqrt = isqrt_newton_schulz_autograd(Cov + self.eps * Id, self.n_iter)
        else:
            # per-group covariance: Cov[g] = E[X_g X_g^T] - E[X_g] E[X_g]^T
            Cov = (XY_sum - X_sum.unsqueeze(2) @ X_mean.unsqueeze(1)) / size
            Id = torch.eye(self.num_features, dtype=Cov.dtype, device=Cov.device).expand(self.groups, self.num_features, self.num_features)
            cov_isqrt = isqrt_newton_schulz_autograd_batch(Cov + self.eps * Id, self.n_iter)

        self.running_mean.mul_(1 - self.momentum)
        self.running_mean.add_(X_mean.detach() * self.momentum)
        self.running_cov_isqrt.mul_(1 - self.momentum)
        self.running_cov_isqrt.add_(cov_isqrt.detach() * self.momentum)

        return X_mean, cov_isqrt
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
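
Both excerpts show only the training-time branch, which folds the freshly reduced X_mean and cov_isqrt into the convolution. At evaluation time the running_mean and running_cov_isqrt buffers maintained in _compute_mean_isqrt would be folded in the same way. The sketch below shows what such an eval-time branch could look like for groups == 1, reusing the same local names (x, C, B) as the training path; it is an assumption based on the training-time folding above, not code taken from either file.

        # hypothetical eval-time branch (groups == 1): substitute the running buffers
        # for the per-batch statistics and fold them into the conv weights
        w = self.weight.view(-1, self.num_features, C // B).transpose(1, 2).contiguous().view(-1, self.num_features) @ self.running_cov_isqrt
        b = self.bias - (w @ self.running_mean.unsqueeze(1)).view(self.weight.shape[0], -1).sum(1)
        w = w.view(-1, C // B, self.num_features).transpose(1, 2).contiguous().view(self.weight.shape)
        x = F.conv2d(x, w, b, self.stride, self.padding, self.dilation, self.groups)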



