def backward()

in apex/apex/parallel/sync_batchnorm_kernel.py


    @staticmethod
    def backward(ctx, grad_output):
        torch.cuda.nvtx.range_push("sync_BN_bw")
        # mini-batch mean & var are calculated by the forward path.
        # mu = 1./N*np.sum(h, axis = 0)
        # var = 1./N*np.sum((h-mu)**2, axis = 0)
        c_last_input, weight, bias, running_mean, running_variance = ctx.saved_tensors

        eps = ctx.eps
        process_group = ctx.process_group
        world_size = ctx.world_size
        grad_input = grad_weight = grad_bias = None
        num_features = running_mean.size()[0]
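        # num_features is C, the per-channel dimension that batch norm normalizes over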

        # transpose to channel-last to support broadcasting for inputs of different rank
        torch.cuda.nvtx.range_push("carilli field")
        c_last_grad = grad_output.transpose(1, -1).contiguous()
        # squash the non-channel dimensions so we can easily calculate the mean
        c_grad = c_last_grad.view(-1, num_features).contiguous()
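        # e.g. for NCHW input: grad_output (N, C, H, W) -> c_last_grad (N, W, H, C)
        #                      -> c_grad (N*W*H, C); reductions over dim 0 then give per-channel values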
        torch.cuda.nvtx.range_pop()

        # calculate grad_input
        if ctx.needs_input_grad[0]:
            # dh = gamma * (var + eps)**(-1. / 2.) * (dy - np.mean(dy, axis=0)
            #     - (h - mu) * (var + eps)**(-1.0) * np.mean(dy * (h - mu), axis=0))
            mean_dy = c_grad.mean(0)
            mean_dy_xmu = (c_last_grad * (c_last_input -
                                          running_mean)).view(-1, num_features).mean(0)
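            # summing the per-rank means and dividing by world_size gives the mean of the
            # per-rank means, which equals the global mean when every rank processes the
            # same number of elements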
            if torch.distributed.is_initialized():
                torch.distributed.all_reduce(
                    mean_dy, ReduceOp.SUM, process_group)
                mean_dy = mean_dy / world_size
                torch.distributed.all_reduce(
                    mean_dy_xmu, ReduceOp.SUM, process_group)
                mean_dy_xmu = mean_dy_xmu / world_size
            c_last_grad_input = (c_last_grad - mean_dy - (c_last_input - running_mean) / (
                running_variance + eps) * mean_dy_xmu) / torch.sqrt(running_variance + eps)
            if weight is not None:
                c_last_grad_input.mul_(weight)
            grad_input = c_last_grad_input.transpose(1, -1).contiguous()

        # calculate grad_weight
        grad_weight = None
        if weight is not None and ctx.needs_input_grad[1]:
            # dgamma = np.sum((h - mu) * (var + eps)**(-1. / 2.) * dy, axis=0)
            grad_weight = ((c_last_input - running_mean) / torch.sqrt(
                running_variance + eps) * c_last_grad).view(-1, num_features).sum(0)

        # calculate grad_bias
        grad_bias = None
        if bias is not None and ctx.needs_input_grad[2]:
            # dbeta = np.sum(dy, axis=0)
            grad_bias = c_grad.sum(0)

        torch.cuda.nvtx.range_pop()
        return grad_input, grad_weight, grad_bias, None, None, None, None, None
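
For a single process, the gradient formulas above reduce to the standard batch-norm backward and can be checked against autograd. The sketch below is an illustration only and does not import apex: the plain batch statistics stand in for the running_mean / running_variance tensors that the forward path saves during training, and the tensor names mirror the kernel.

import torch

torch.manual_seed(0)
N, C, H, W = 4, 3, 5, 5
eps = 1e-5

x = torch.randn(N, C, H, W, dtype=torch.double, requires_grad=True)
weight = torch.randn(C, dtype=torch.double, requires_grad=True)
bias = torch.randn(C, dtype=torch.double, requires_grad=True)
grad_output = torch.randn(N, C, H, W, dtype=torch.double)

# reference forward: normalize with the biased batch statistics, then apply the affine transform
mean = x.mean(dim=(0, 2, 3))
var = x.var(dim=(0, 2, 3), unbiased=False)
x_hat = (x - mean[None, :, None, None]) / torch.sqrt(var[None, :, None, None] + eps)
(x_hat * weight[None, :, None, None] + bias[None, :, None, None]).backward(grad_output)

# manual backward, written channel-last exactly like the kernel above
with torch.no_grad():
    c_last_x = x.detach().transpose(1, -1).contiguous()
    c_last_grad = grad_output.transpose(1, -1).contiguous()
    c_grad = c_last_grad.view(-1, C)

    mean_dy = c_grad.mean(0)
    mean_dy_xmu = (c_last_grad * (c_last_x - mean)).view(-1, C).mean(0)
    c_last_grad_input = (c_last_grad - mean_dy
                         - (c_last_x - mean) / (var + eps) * mean_dy_xmu) / torch.sqrt(var + eps)
    c_last_grad_input.mul_(weight)
    grad_input = c_last_grad_input.transpose(1, -1).contiguous()

    grad_weight = ((c_last_x - mean) / torch.sqrt(var + eps) * c_last_grad).view(-1, C).sum(0)
    grad_bias = c_grad.sum(0)

print(torch.allclose(grad_input, x.grad))        # expected: True
print(torch.allclose(grad_weight, weight.grad))  # expected: True
print(torch.allclose(grad_bias, bias.grad))      # expected: True

In the multi-process path the only difference is that mean_dy and mean_dy_xmu are all-reduced and divided by world_size before entering the same formula.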