in inplace_abn/functions.py [0:0]
def _reduce_forward(mean, var, count, group, world_size):
    """Gather local statistics from every process and reduce them.

    The local ``mean``/``var`` pair is concatenated and all-gathered over
    ``group`` in one collective, the local ``count`` is all-gathered in a
    second one, and the gathered tensors are passed to
    ``_backend.reduce_statistics``.

    Args:
        mean: Local mean tensor (presumably 1-D -- TODO confirm with caller).
        var: Local variance tensor; expected to have the same number of
            elements as ``mean``, since the gathered buffer is split in
            chunks of ``mean.numel()`` columns and unpacked into two parts.
        count: Local count tensor (presumably a single element, as each
            gathered slot has shape ``(1,)`` -- TODO confirm with caller).
        group: Process group forwarded to ``distributed.all_gather``.
        world_size: Number of rows allocated in the gather buffers, one per
            participating process.

    Returns:
        Whatever ``_backend.reduce_statistics(all_mean, all_var, all_count)``
        returns.
    """
    # Mean and variance travel together so one collective moves both.
    local_stats = torch.cat((mean, var), dim=0)
    gathered_stats = local_stats.new_empty((world_size, local_stats.numel()))
    # unbind(0) yields one row view per process, so all_gather fills
    # ``gathered_stats`` in place.
    stats_slots = list(gathered_stats.unbind(0))
    distributed.all_gather(stats_slots, local_stats, group=group, async_op=False)
    # Columns [0, mean.numel()) hold the means, the remainder the variances.
    all_mean, all_var = torch.split(gathered_stats, mean.numel(), dim=1)

    # Per-process counts, gathered into a (world_size, 1) buffer the same way.
    gathered_count = count.new_empty((world_size, 1))
    count_slots = list(gathered_count.unbind(0))
    distributed.all_gather(count_slots, count, group=group, async_op=False)

    return _backend.reduce_statistics(all_mean, all_var, gathered_count)