in horovod/torch/sync_batch_norm.py [0:0]
def backward(self, grad_output):
    grad_output = grad_output.contiguous()
    saved_input, weight, mean, invstd, count_all = self.saved_tensors
    need_input_grad, need_weight_grad, need_bias_grad = self.needs_input_grad[0:3]
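    # needs_input_grad[0:3] corresponds to the (input, weight, bias) arguments of
    # the paired forward(), i.e. which gradients actually have to be produced.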
    # calculate local stats as well as grad_weight / grad_bias
    sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
        grad_output,
        saved_input,
        mean,
        invstd,
        weight,
        need_input_grad,
        need_weight_grad,
        need_bias_grad
    )
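    # sum_dy / sum_dy_xmu are this worker's per-channel sums of grad_output and of
    # grad_output * (input - mean); on PyTorch 1.5.0+ they are sums rather than
    # means, which is why the version check below divides by the element count.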
    if need_input_grad:
        # synchronizing stats used to calculate input gradient.
        sum_dy_handle = allreduce_async(sum_dy, op=Sum, name='sync_batch_norm.sum_dy')
        sum_dy_xmu_handle = allreduce_async(sum_dy_xmu, op=Sum, name='sync_batch_norm.sum_dy_xmu')
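        # the Sum allreduce turns the per-worker partial sums into global sums
        # across all Horovod processes; issuing both asynchronously lets the two
        # reductions overlap.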
        # wait on the async communication to finish
        sum_dy = synchronize(sum_dy_handle)
        sum_dy_xmu = synchronize(sum_dy_xmu_handle)
        if _SYNC_BN_V2 or _SYNC_BN_V3:
            # count_all was saved in forward() and holds every worker's element
            # count, so its sum is the total number of elements behind the stats
            count_all_sum = count_all.sum()
            mean_dy = sum_dy / count_all_sum
            mean_dy_xmu = sum_dy_xmu / count_all_sum
        else:
            # before 1.5.0, sum_dy was the sum of means from every worker, so we
            # just need to divide it by the number of workers
            mean_dy = sum_dy / size()
            mean_dy_xmu = sum_dy_xmu / size()
        # backward pass for gradient calculation
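        # batch_norm_backward_elemt applies the usual batch-norm backward formula
        # elementwise, roughly:
        #   grad_input = (grad_output - mean_dy
        #                 - (input - mean) * invstd ** 2 * mean_dy_xmu) * invstd * weight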
        grad_input = torch.batch_norm_backward_elemt(
            grad_output,
            saved_input,
            mean,
            invstd,
            weight,
            mean_dy,
            mean_dy_xmu
        )
    else:
        grad_input = None
    # synchronizing grad_weight / grad_bias is not needed, as distributed
    # training handles the allreduce of those gradients.
    if weight is None or not need_weight_grad:
        grad_weight = None
    if weight is None or not need_bias_grad:
        grad_bias = None
    return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
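
For context, here is a minimal usage sketch (not part of the file above; it assumes Horovod built with PyTorch support and a job launched via horovodrun with more than one worker). hvd.SyncBatchNorm wraps the autograd function containing this backward(), so a plain loss.backward() triggers the allreduces shown above.

import torch
import horovod.torch as hvd

hvd.init()
torch.cuda.set_device(hvd.local_rank())

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=3),
    hvd.SyncBatchNorm(16),  # drop-in for torch.nn.BatchNorm2d(16), stats synced across workers
    torch.nn.ReLU(),
).cuda()

loss = model(torch.randn(8, 3, 32, 32).cuda()).sum()
loss.backward()  # invokes the backward() above, including the sum_dy / sum_dy_xmu allreduces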