def _reduce_forward()

in inplace_abn/functions.py [0:0]


    def _reduce_forward(mean, var, count, group, world_size):
        """Gather statistics from every rank of ``group`` and reduce them.

        ``mean`` and ``var`` are concatenated along dim 0 into one flat buffer so a
        single ``all_gather`` transfers both. ``count`` is gathered by a second
        ``all_gather``. Both collectives are blocking (``async_op=False``). Every rank
        in ``group`` must reach this call, since it issues collective operations.

        Args:
            mean: Local mean tensor. Its ``numel()`` is used to split the gathered
                buffer back into mean and variance columns. Presumably 1-D
                per-channel statistics -- TODO confirm against callers.
            var: Local variance tensor. It is expected to have as many elements as
                ``mean``, so that the gathered buffer splits into two equal halves.
            count: Local sample count. It is gathered into a ``(world_size, 1)``
                buffer, so it presumably holds a single element -- verify.
            group: ``torch.distributed`` process group used for both gathers.
            world_size: Number of ranks in ``group``; it sizes the receive buffers.

        Returns:
            The result of ``_backend.reduce_statistics`` applied to the gathered
            means and variances (each ``(world_size, mean.numel())``) and the
            gathered counts (``(world_size, 1)``).
        """
        split_size = mean.numel()

        # One packed buffer per rank: [mean | var].
        local_stats = torch.cat([mean, var], dim=0)
        gathered_stats = local_stats.new_empty(world_size, local_stats.numel())
        # unbind() yields views of the rows, so all_gather writes straight into
        # gathered_stats without an extra copy.
        stat_rows = list(gathered_stats.unbind(0))
        distributed.all_gather(stat_rows, local_stats, group=group, async_op=False)
        gathered_mean, gathered_var = gathered_stats.split(split_size, dim=1)

        # Per-rank counts, one row per rank.
        gathered_count = count.new_empty(world_size, 1)
        count_rows = list(gathered_count.unbind(0))
        distributed.all_gather(count_rows, count, group=group, async_op=False)

        return _backend.reduce_statistics(
            gathered_mean, gathered_var, gathered_count
        )