in Dassl.pytorch/dassl/modeling/ops/transnorm.py [0:0]
def forward(self, input):
    """Apply transferable (domain-specific) normalization to ``input``.

    Training mode: the batch is split in half along dim 0. The first half
    is normalized with its own ("source") batch statistics and the second
    half with its own ("target") batch statistics. Both halves share the
    affine ``weight``/``bias``. Each domain's running buffers are updated
    in place as ``running = momentum * running + (1 - momentum) * batch``.

    Eval mode: every sample is normalized with the target running
    statistics (``running_mean_t`` / ``running_var_t``).

    If ``self.adaptive_alpha`` is set, the output is additionally scaled by
    ``1 + alpha``. Here ``alpha = self._compute_alpha(mean_s, var_s, mean_t,
    var_t)``, detached so no gradient flows through it.

    Args:
        input: a 2-D ``(N, C)`` or 4-D ``(N, C, H, W)`` tensor with
            ``C == self.num_features``. In training mode ``N`` must be even
            and non-zero: source samples come first, then target samples.

    Returns:
        A tensor with the same shape as ``input``.

    Raises:
        ValueError: if ``input`` is neither 2-D nor 4-D, if the training
            batch size is odd or zero, or if a domain half has only one
            value per channel (its variance would be undefined / NaN).
    """
    # NOTE(review): _check_input is defined elsewhere; presumably it
    # validates the input rank/shape -- confirm.
    self._check_input(input)
    C = self.num_features
    if input.dim() == 2:
        new_shape = (1, C)
    elif input.dim() == 4:
        new_shape = (1, C, 1, 1)
    else:
        raise ValueError(
            "TransNorm expects a 2-D or 4-D input, got {}-D input".format(
                input.dim()
            )
        )
    # Per-channel tensors are broadcast against the channel axis (dim 1).
    weight = self.weight.view(*new_shape)
    bias = self.bias.view(*new_shape)

    if not self.training:
        # Inference: all samples are normalized with the target statistics.
        mean_t = self.running_mean_t.view(*new_shape)
        var_t = self.running_var_t.view(*new_shape)
        output = (input - mean_t) / (var_t + self.eps).sqrt()
        output = output * weight + bias
        if self.adaptive_alpha:
            mean_s = self.running_mean_s.view(*new_shape)
            var_s = self.running_var_s.view(*new_shape)
            alpha = self._compute_alpha(mean_s, var_s, mean_t, var_t)
            alpha = alpha.reshape(*new_shape)
            output = (1 + alpha.detach()) * output
        return output

    batch_size = input.shape[0]
    if batch_size == 0 or batch_size % 2 != 0:
        # torch.split would otherwise yield != 2 chunks and fail opaquely.
        raise ValueError(
            "TransNorm expects an even, non-zero batch size in training "
            "mode (source half followed by target half), got {}".format(
                batch_size
            )
        )

    def normalize(x, running_mean, running_var):
        # Normalize one domain with its own batch statistics and fold those
        # statistics into that domain's running buffers (in place).
        flat = x.transpose(0, 1).reshape(C, -1)  # (C, values per channel)
        if flat.shape[1] < 2:
            # Unbiased variance of a single value is NaN, which would
            # corrupt both the output and the running buffers.
            raise ValueError(
                "Expected more than 1 value per channel when training, "
                "got input size {}".format(list(x.shape))
            )
        mean = flat.mean(1)
        var = flat.var(1)  # unbiased, as in the original implementation
        with torch.no_grad():
            running_mean.mul_(self.momentum)
            running_mean.add_((1 - self.momentum) * mean.detach())
            running_var.mul_(self.momentum)
            running_var.add_((1 - self.momentum) * var.detach())
        mean = mean.reshape(*new_shape)
        var = var.reshape(*new_shape)
        out = (x - mean) / (var + self.eps).sqrt()
        return out * weight + bias, mean, var

    half = batch_size // 2
    input_s, input_t = input[:half], input[half:]
    output_s, mean_s, var_s = normalize(
        input_s, self.running_mean_s, self.running_var_s
    )
    output_t, mean_t, var_t = normalize(
        input_t, self.running_mean_t, self.running_var_t
    )
    output = torch.cat([output_s, output_t], 0)
    if self.adaptive_alpha:
        # NOTE(review): _compute_alpha is defined elsewhere; it receives the
        # reshaped per-domain batch statistics.
        alpha = self._compute_alpha(mean_s, var_s, mean_t, var_t)
        alpha = alpha.reshape(*new_shape)
        output = (1 + alpha.detach()) * output
    return output