in models/networks/sync_batchnorm/batchnorm.py [0:0]
def forward(self, input):
    """Batch-normalize ``input``, synchronizing statistics across replicas.

    Outside of data-parallel training (``self._is_parallel`` false, or the
    module in eval mode) this defers to ``F.batch_norm`` with the module's
    running buffers and affine parameters.

    In parallel training mode each replica computes its local per-channel
    sum and square-sum. It then exchanges them through the sync master
    (replica 0) or its slave pipe (other replicas) and receives the shared
    ``mean`` and ``inv_std`` back. These are used to normalize the local
    input.

    Args:
        input: Tensor of shape ``(B, C, *)`` where ``C == self.num_features``.

    Returns:
        Tensor with the same shape as ``input``.

    Raises:
        ValueError: in parallel training mode, if ``input`` has fewer than
            two dimensions or its channel dimension does not match
            ``self.num_features``.
    """
    # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
    if not (self._is_parallel and self.training):
        return F.batch_norm(
            input,
            self.running_mean,
            self.running_var,
            self.weight,
            self.bias,
            self.training,
            self.momentum,
            self.eps,
        )

    input_shape = input.size()
    # Without this check, an input like (B, 2C, L) would be silently viewed
    # as (B, C, 2L) below, producing wrong statistics with no error.
    if input.dim() < 2 or input.size(1) != self.num_features:
        raise ValueError(
            "Expected input of shape (B, {0}, ...), got {1}.".format(
                self.num_features, tuple(input_shape)
            )
        )

    # Resize the input to (B, C, -1). `reshape` (unlike `view`) also accepts
    # non-contiguous inputs, copying only when a view is impossible.
    input = input.reshape(input.size(0), self.num_features, -1)

    # Compute the sum and square-sum; sum_size is the number of elements
    # reduced per channel on this replica (batch * flattened spatial).
    sum_size = input.size(0) * input.size(2)
    input_sum = _sum_ft(input)
    input_ssum = _sum_ft(input ** 2)

    # Reduce-and-broadcast the statistics: replica 0 acts as master, the
    # others talk to it through their slave pipes.
    message = _ChildMessage(input_sum, input_ssum, sum_size)
    if self._parallel_id == 0:
        mean, inv_std = self._sync_master.run_master(message)
    else:
        mean, inv_std = self._slave_pipe.run_slave(message)

    # Compute the output.
    if self.affine:
        # MJY:: Fuse the multiplication for speed.
        output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(
            inv_std * self.weight
        ) + _unsqueeze_ft(self.bias)
    else:
        output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

    # Reshape back to the caller's original shape.
    return output.reshape(input_shape)