in models/modules.py [0:0]
def forward(self, input):
    """Normalize ``input`` using the first ``c`` channels of the parameters.

    ``c`` is ``input.shape[1]``. The weight, bias and running statistics are
    truncated to their first ``c`` entries. A single (wider) parameter set
    can therefore serve inputs with fewer channels.

    Dispatch on ``self.mode``:

    * ``"BatchNorm"`` with ``self.width_factor`` listed in
      ``self.width_factors_list``: uses the running statistics of the
      matching ``self.bn[idx]`` entry.
    * ``"BatchNorm"`` / ``"InstanceNorm"`` otherwise: uses
      ``self.running_mean`` / ``self.running_var``, either of which may be
      ``None``.
    * ``"GroupNorm"``: ``self.num_groups == "full"`` means one group per
      channel.

    NOTE(review): the keyword arguments match the signatures of
    ``torch.nn.functional.batch_norm`` / ``instance_norm`` / ``group_norm``.
    ``self.bn_func`` is presumably one of those; confirm where it is set.

    Args:
        input: Tensor-like object whose dimension 1 is the channel dimension.

    Returns:
        Whatever ``self.bn_func(**kwargs)`` returns.

    Raises:
        NotImplementedError: if ``self.mode`` is not one of the modes above.
    """
    weight, bias = self.get_weight()
    c = input.shape[1]

    def _head(t):
        # Keep the first ``c`` channels. ``None`` (no affine parameters, or
        # running stats that are not tracked) passes through unchanged
        # instead of raising TypeError.
        return None if t is None else t[:c]

    def _use_batch_stats(running_mean, running_var):
        # Use batch/instance statistics in training mode, and in eval mode
        # when there are no running statistics to fall back on. This mirrors
        # torch.nn's BatchNorm/InstanceNorm. It avoids F.batch_norm raising
        # for ``None`` buffers with training=False, and stops F.instance_norm
        # from updating tracked buffers during eval.
        return self.training or running_mean is None or running_var is None

    if (
        self.mode == "BatchNorm"
        and self.width_factors_list is not None
        and self.width_factor in self.width_factors_list
    ):
        # Normally, we expect width_factors_list to be empty, because we
        # only want to use it if we are running sanity checks (e.g.
        # recreating the original performance or something).
        idx = self.width_factors_list.index(self.width_factor)
        running_mean = _head(self.bn[idx].running_mean)
        running_var = _head(self.bn[idx].running_var)
        kwargs = {
            "input": input,
            "running_mean": running_mean,
            "running_var": running_var,
            "weight": _head(weight),
            "bias": _head(bias),
            "training": _use_batch_stats(running_mean, running_var),
            "momentum": self.momentum,
            "eps": self.eps,
        }
    elif self.mode in ("InstanceNorm", "BatchNorm"):
        # Running stats may be None when they are not tracked; _head only
        # slices them when present.
        running_mean = _head(self.running_mean)
        running_var = _head(self.running_var)
        kwargs = {
            "input": input,
            "running_mean": running_mean,
            "running_var": running_var,
            "weight": _head(weight),
            "bias": _head(bias),
            "momentum": self.momentum,
            "eps": self.eps,
        }
        use_batch_stats = _use_batch_stats(running_mean, running_var)
        if self.mode == "BatchNorm":
            kwargs["training"] = use_batch_stats
        else:
            kwargs["use_input_stats"] = use_batch_stats
    elif self.mode == "GroupNorm":
        num_groups = self.num_groups
        if num_groups == "full":
            # One group per channel of the current input.
            num_groups = c
        kwargs = {
            "input": input,
            "num_groups": num_groups,
            "weight": _head(weight),
            "bias": _head(bias),
            "eps": self.eps,
        }
    else:
        raise NotImplementedError(f"Invalid mode {self.mode}.")
    return self.bn_func(**kwargs)