Classification/models/SyncND.py [355:454]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        return x



    def __data_parallel_replicate__(self, ctx, copy_id):
        """Mark this module as a parallel replica and connect it to the shared context.

        Args:
            ctx: context object shared by the replicas; replica 0 stores its
                ``_sync_master`` on it as ``ctx.sync_master`` and the other
                replicas read it from there.
            copy_id: index of this replica; ``0`` designates the master.
        """
        self._is_parallel = True
        self._parallel_id = copy_id

        if copy_id != 0:
            # Every non-master replica registers with the master published on
            # ``ctx`` and keeps the handle returned by ``register_slave``.
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)
            return

        # Replica 0 is the master: expose its sync master to the other replicas.
        ctx.sync_master = self._sync_master

    def _data_parallel_master(self, intermediates):
        """Combine the per-replica sums, derive the shared statistics and send them back.

        Args:
            intermediates: sequence of records; element ``1`` of each record is a
                message exposing ``X_sum`` and ``sum_size`` whose first two
                fields are the tensors to reduce. Element ``0`` is passed
                through unchanged into the result.

        Returns:
            A list with one ``(record[0], _MasterMessage(...))`` pair per record,
            in device order, each message wrapping that device's copy of the
            mean and the inverse square root computed by
            ``_compute_mean_isqrt``.
        """
        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        ordered = sorted(intermediates, key=lambda rec: rec[1].X_sum.get_device())

        # One pass collects the tensors to reduce (two per replica, flattened),
        # the device of each replica and the overall sample count.
        to_reduce = []
        target_gpus = []
        total_sum = 0
        for rec in ordered:
            message = rec[1]
            to_reduce.extend(message[:2])
            target_gpus.append(message.X_sum.get_device())
            total_sum = total_sum + message.sum_size

        # Sum both quantities onto the first device, then compute the statistics.
        X_sum, XY_sum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        X_mean, cov_isqrt = self._compute_mean_isqrt(X_sum, XY_sum, total_sum)

        # ``broadcasted`` is flat: two consecutive tensors per target device.
        broadcasted = Broadcast.apply(target_gpus, X_mean, cov_isqrt)

        return [
            (rec[0], _MasterMessage(*broadcasted[2 * idx:2 * idx + 2]))
            for idx, rec in enumerate(ordered)
        ]

    def _compute_mean_isqrt(self, X_sum, XY_sum, size):
        """Derive the mean and the inverse square root of the covariance from the
        reduced sums, and fold both into the running buffers.

        Args:
            X_sum: summed inputs.
            XY_sum: summed second-order term from which the covariance is formed
                (presumably the summed outer products -- confirm against the
                forward pass).
            size: total number of samples behind the sums; must exceed 1.

        Returns:
            Tuple ``(X_mean, cov_isqrt)``. ``cov_isqrt`` is the result of the
            Newton-Schulz helper applied to ``Cov + eps * I`` with ``n_iter``
            iterations.
        """
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        X_mean = X_sum / size

        if self.groups != 1:
            # Grouped case: batched outer product of sum and mean, one
            # covariance matrix per group.
            outer = X_sum.unsqueeze(2) @ X_mean.unsqueeze(1)
            Cov = (XY_sum - outer) / size
            eye = torch.eye(self.num_features, dtype=Cov.dtype, device=Cov.device)
            Id = eye.expand(self.groups, self.num_features, self.num_features)
            regularised = Cov + self.eps * Id
            cov_isqrt = isqrt_newton_schulz_autograd_batch(regularised, self.n_iter)
        else:
            # Single group: one covariance matrix.
            outer = X_sum.unsqueeze(1) @ X_mean.unsqueeze(0)
            Cov = (XY_sum - outer) / size
            Id = torch.eye(Cov.shape[1], dtype=Cov.dtype, device=Cov.device)
            regularised = Cov + self.eps * Id
            cov_isqrt = isqrt_newton_schulz_autograd(regularised, self.n_iter)

        # Exponential moving average, updated in place on detached values so
        # the buffers do not take part in autograd.
        momentum = self.momentum
        self.running_mean.mul_(1 - momentum).add_(X_mean.detach() * momentum)
        self.running_cov_isqrt.mul_(1 - momentum).add_(cov_isqrt.detach() * momentum)

        return X_mean, cov_isqrt




@contextlib.contextmanager
def patch_sync_nd():
    """Temporarily rebind ``FastDeconv`` and ``Delinear`` to their synchronized versions.

    While the ``with`` block is active, this module's global names
    ``FastDeconv`` and ``Delinear`` refer to ``SynchronizedDeconv`` and
    ``SynchronizedDelinear``. The previous bindings are restored on exit, even
    if the body raises.

    NOTE(review): only this module's globals are rebound; a module that
    imported the classes under its own name keeps its own reference.

    Yields:
        None.
    """
    # The assignments below would otherwise make both names function-local, so
    # reading them for the backup raised UnboundLocalError and the module-level
    # bindings were never patched.
    global FastDeconv, Delinear

    backup = FastDeconv, Delinear

    FastDeconv = SynchronizedDeconv
    Delinear = SynchronizedDelinear

    try:
        yield
    finally:
        # Undo the patch no matter how the ``with`` body finished.
        FastDeconv, Delinear = backup


def convert_sync_nd(module):
    """Traverse the input module and its child recursively
       and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d
       to SynchronizedBatchNorm*N*d

    Args:
        module: the input module needs to be convert to SyncBN model

    Examples:
        >>> import torch.nn as nn
        >>> import torchvision
        >>> # m is a standard pytorch model
        >>> m = torchvision.models.resnet18(True)
        >>> m = nn.DataParallel(m)
        >>> # after convert, m is using SyncBN
        >>> m = convert_model(m)
    """
    if isinstance(module, torch.nn.DataParallel):
        mod = module.module
        mod = convert_sync_nd(mod)
        mod = DataParallelWithCallback(mod, device_ids=module.device_ids)
        return mod

    mod = module
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



Segmentation/models/segmentation/SyncND.py [118:217]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        return x



    def __data_parallel_replicate__(self, ctx, copy_id):
        """Mark this module as a parallel replica and connect it to the shared context.

        Args:
            ctx: context object shared by the replicas; replica 0 stores its
                ``_sync_master`` on it as ``ctx.sync_master`` and the other
                replicas read it from there.
            copy_id: index of this replica; ``0`` designates the master.
        """
        self._is_parallel = True
        self._parallel_id = copy_id

        if copy_id != 0:
            # Every non-master replica registers with the master published on
            # ``ctx`` and keeps the handle returned by ``register_slave``.
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)
            return

        # Replica 0 is the master: expose its sync master to the other replicas.
        ctx.sync_master = self._sync_master

    def _data_parallel_master(self, intermediates):
        """Combine the per-replica sums, derive the shared statistics and send them back.

        Args:
            intermediates: sequence of records; element ``1`` of each record is a
                message exposing ``X_sum`` and ``sum_size`` whose first two
                fields are the tensors to reduce. Element ``0`` is passed
                through unchanged into the result.

        Returns:
            A list with one ``(record[0], _MasterMessage(...))`` pair per record,
            in device order, each message wrapping that device's copy of the
            mean and the inverse square root computed by
            ``_compute_mean_isqrt``.
        """
        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        ordered = sorted(intermediates, key=lambda rec: rec[1].X_sum.get_device())

        # One pass collects the tensors to reduce (two per replica, flattened),
        # the device of each replica and the overall sample count.
        to_reduce = []
        target_gpus = []
        total_sum = 0
        for rec in ordered:
            message = rec[1]
            to_reduce.extend(message[:2])
            target_gpus.append(message.X_sum.get_device())
            total_sum = total_sum + message.sum_size

        # Sum both quantities onto the first device, then compute the statistics.
        X_sum, XY_sum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        X_mean, cov_isqrt = self._compute_mean_isqrt(X_sum, XY_sum, total_sum)

        # ``broadcasted`` is flat: two consecutive tensors per target device.
        broadcasted = Broadcast.apply(target_gpus, X_mean, cov_isqrt)

        return [
            (rec[0], _MasterMessage(*broadcasted[2 * idx:2 * idx + 2]))
            for idx, rec in enumerate(ordered)
        ]

    def _compute_mean_isqrt(self, X_sum, XY_sum, size):
        """Derive the mean and the inverse square root of the covariance from the
        reduced sums, and fold both into the running buffers.

        Args:
            X_sum: summed inputs.
            XY_sum: summed second-order term from which the covariance is formed
                (presumably the summed outer products -- confirm against the
                forward pass).
            size: total number of samples behind the sums; must exceed 1.

        Returns:
            Tuple ``(X_mean, cov_isqrt)``. ``cov_isqrt`` is the result of the
            Newton-Schulz helper applied to ``Cov + eps * I`` with ``n_iter``
            iterations.
        """
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        X_mean = X_sum / size

        if self.groups != 1:
            # Grouped case: batched outer product of sum and mean, one
            # covariance matrix per group.
            outer = X_sum.unsqueeze(2) @ X_mean.unsqueeze(1)
            Cov = (XY_sum - outer) / size
            eye = torch.eye(self.num_features, dtype=Cov.dtype, device=Cov.device)
            Id = eye.expand(self.groups, self.num_features, self.num_features)
            regularised = Cov + self.eps * Id
            cov_isqrt = isqrt_newton_schulz_autograd_batch(regularised, self.n_iter)
        else:
            # Single group: one covariance matrix.
            outer = X_sum.unsqueeze(1) @ X_mean.unsqueeze(0)
            Cov = (XY_sum - outer) / size
            Id = torch.eye(Cov.shape[1], dtype=Cov.dtype, device=Cov.device)
            regularised = Cov + self.eps * Id
            cov_isqrt = isqrt_newton_schulz_autograd(regularised, self.n_iter)

        # Exponential moving average, updated in place on detached values so
        # the buffers do not take part in autograd.
        momentum = self.momentum
        self.running_mean.mul_(1 - momentum).add_(X_mean.detach() * momentum)
        self.running_cov_isqrt.mul_(1 - momentum).add_(cov_isqrt.detach() * momentum)

        return X_mean, cov_isqrt




@contextlib.contextmanager
def patch_sync_nd():
    """Temporarily rebind ``FastDeconv`` and ``Delinear`` to their synchronized versions.

    While the ``with`` block is active, this module's global names
    ``FastDeconv`` and ``Delinear`` refer to ``SynchronizedDeconv`` and
    ``SynchronizedDelinear``. The previous bindings are restored on exit, even
    if the body raises.

    NOTE(review): only this module's globals are rebound; a module that
    imported the classes under its own name keeps its own reference.

    Yields:
        None.
    """
    # The assignments below would otherwise make both names function-local, so
    # reading them for the backup raised UnboundLocalError and the module-level
    # bindings were never patched.
    global FastDeconv, Delinear

    backup = FastDeconv, Delinear

    FastDeconv = SynchronizedDeconv
    Delinear = SynchronizedDelinear

    try:
        yield
    finally:
        # Undo the patch no matter how the ``with`` body finished.
        FastDeconv, Delinear = backup


def convert_sync_nd(module):
    """Traverse the input module and its child recursively
       and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d
       to SynchronizedBatchNorm*N*d

    Args:
        module: the input module needs to be convert to SyncBN model

    Examples:
        >>> import torch.nn as nn
        >>> import torchvision
        >>> # m is a standard pytorch model
        >>> m = torchvision.models.resnet18(True)
        >>> m = nn.DataParallel(m)
        >>> # after convert, m is using SyncBN
        >>> m = convert_model(m)
    """
    if isinstance(module, torch.nn.DataParallel):
        mod = module.module
        mod = convert_sync_nd(mod)
        mod = DataParallelWithCallback(mod, device_ids=module.device_ids)
        return mod

    mod = module
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



