in models/modules.py [0:0]
def forward(self, input):
    """Batch-normalize ``input`` using the affine pair returned by ``get_weight``.

    ``self.get_weight()`` (which samples from the subspace) supplies the
    scale and shift for this call. Everything else follows the stock PyTorch
    batch-norm forward pass: running-statistics bookkeeping, choosing between
    batch and running statistics, then a single ``F.batch_norm`` call.

    Args:
        input: Tensor handed unchanged to ``F.batch_norm``.

    Returns:
        The tensor produced by ``F.batch_norm``.
    """
    weight, bias = self.get_weight()

    tracking = self.track_running_stats
    # Factor handed to F.batch_norm as its ``momentum`` argument.
    avg_factor = 0.0 if self.momentum is None else self.momentum

    if self.training and tracking and self.num_batches_tracked is not None:
        # Rebinds the attribute to a new tensor (not an in-place add).
        self.num_batches_tracked = self.num_batches_tracked + 1
        if self.momentum is None:
            # No momentum: cumulative moving average over all batches seen.
            avg_factor = 1.0 / float(self.num_batches_tracked)

    # Batch statistics are used while training, and in eval mode only when
    # there are no running buffers to fall back on.
    use_batch_stats = True if self.training else (
        self.running_mean is None and self.running_var is None
    )

    # Passing None for the buffers guarantees F.batch_norm will not update
    # them when training without tracked running statistics.
    hand_over_buffers = (not self.training) or tracking
    running_mean = self.running_mean if hand_over_buffers else None
    running_var = self.running_var if hand_over_buffers else None

    return F.batch_norm(
        input,
        running_mean,
        running_var,
        weight=weight,
        bias=bias,
        training=use_batch_stats,
        momentum=avg_factor,
        eps=self.eps,
    )