in apex/apex/parallel/sync_batchnorm_kernel.py [0:0]
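
# Backward pass of the Python SyncBatchNorm autograd Function. Assumes the
# module-level imports of this file are in scope, e.g. `torch` and a `ReduceOp`
# alias for `torch.distributed.ReduceOp`; the method lives on a
# torch.autograd.Function subclass, hence the @staticmethod decorator.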
@staticmethod
def backward(ctx, grad_output):
    torch.cuda.nvtx.range_push("sync_BN_bw")
    # mini-batch mean & var were calculated by the forward path:
    # mu = 1./N * np.sum(h, axis=0)
    # var = 1./N * np.sum((h - mu)**2, axis=0)
    c_last_input, weight, bias, running_mean, running_variance = ctx.saved_tensors

    eps = ctx.eps
    process_group = ctx.process_group
    world_size = ctx.world_size
    grad_input = grad_weight = grad_bias = None
    num_features = running_mean.size()[0]
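
    # grad_output arrives in the layout of the forward input, e.g. (N, C, ...);
    # the steps below flatten it to (-1, num_features) so every reduction is a
    # simple per-channel mean or sum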
    # transpose to channel-last to support broadcasting for inputs of different rank
    torch.cuda.nvtx.range_push("carilli field")
    c_last_grad = grad_output.transpose(1, -1).contiguous()
    # squash the non-channel dimensions so we can easily calculate per-channel means
    c_grad = c_last_grad.view(-1, num_features).contiguous()
    torch.cuda.nvtx.range_pop()

    # calculate grad_input
    if ctx.needs_input_grad[0]:
        # dh = gamma * (var + eps)**(-1. / 2.) * (dy - np.mean(dy, axis=0)
        #      - (h - mu) * (var + eps)**(-1.0) * np.mean(dy * (h - mu), axis=0))
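        # local (per-process) per-channel statistics of the incoming gradient:
        #   mean_dy     = E[dy]
        #   mean_dy_xmu = E[dy * (h - mu)]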
        mean_dy = c_grad.mean(0)
        mean_dy_xmu = (c_last_grad * (c_last_input -
                                      running_mean)).view(-1, num_features).mean(0)
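        # when torch.distributed is initialized, average both statistics across
        # all processes in the group so every rank computes the same gradient,
        # mirroring the synchronized statistics used in the forward pass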
        if torch.distributed.is_initialized():
            torch.distributed.all_reduce(
                mean_dy, ReduceOp.SUM, process_group)
            mean_dy = mean_dy / world_size
            torch.distributed.all_reduce(
                mean_dy_xmu, ReduceOp.SUM, process_group)
            mean_dy_xmu = mean_dy_xmu / world_size
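        # apply the input-gradient formula from the comment above; everything
        # broadcasts over the trailing channel dimension, and gamma (weight)
        # is multiplied in afterwards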
        c_last_grad_input = (c_last_grad - mean_dy - (c_last_input - running_mean) / (
            running_variance + eps) * mean_dy_xmu) / torch.sqrt(running_variance + eps)
        if weight is not None:
            c_last_grad_input.mul_(weight)
        grad_input = c_last_grad_input.transpose(1, -1).contiguous()
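
    # grad_weight / grad_bias below are reduced only over the local flattened
    # batch; unlike mean_dy / mean_dy_xmu they are not all-reduced here, so the
    # cross-process reduction is presumably left to the data-parallel wrapper's
    # gradient all-reduce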
    # calculate grad_weight
    grad_weight = None
    if weight is not None and ctx.needs_input_grad[1]:
        # dgamma = np.sum((h - mu) * (var + eps)**(-1. / 2.) * dy, axis=0)
        grad_weight = ((c_last_input - running_mean) / torch.sqrt(
            running_variance + eps) * c_last_grad).view(-1, num_features).sum(0)

    # calculate grad_bias
    grad_bias = None
    if bias is not None and ctx.needs_input_grad[2]:
        # dbeta = np.sum(dy, axis=0)
        grad_bias = c_grad.sum(0)

    torch.cuda.nvtx.range_pop()
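
    # backward must return one gradient per forward input; the trailing Nones
    # cover the non-differentiable arguments (running stats, eps, process
    # group, world size)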
    return grad_input, grad_weight, grad_bias, None, None, None, None, None