def forward()

in Dassl.pytorch/dassl/modeling/ops/transnorm.py [0:0]


    def forward(self, input):
        """Normalize ``input`` per channel, by halves when training.

        In evaluation mode the whole batch is normalized with the
        ``running_mean_t`` / ``running_var_t`` buffers. In training mode the
        batch is split in two along dim 0: the first half is normalized with
        its own batch statistics and updates the ``*_s`` running buffers, the
        second half does the same with the ``*_t`` buffers (presumably the
        source and target domains -- confirm against the caller). When
        ``self.adaptive_alpha`` is set, the result is additionally scaled by
        ``1 + alpha``, where ``alpha`` comes from ``self._compute_alpha`` and
        is detached from the autograd graph.

        Args:
            input: 2-D or 4-D tensor with ``self.num_features`` entries
                along dim 1. In training mode dim 0 must have an even,
                non-zero size.

        Returns:
            The normalized tensor, with the same shape as ``input``.

        Raises:
            ValueError: if ``input`` is not 2-D or 4-D, or if in training
                mode the batch cannot be split into two equal halves.
        """
        self._check_input(input)
        C = self.num_features
        if input.dim() == 2:
            new_shape = (1, C)
        elif input.dim() == 4:
            new_shape = (1, C, 1, 1)
        else:
            raise ValueError(
                "expected a 2-D or 4-D input, got {}-D".format(input.dim())
            )

        # Per-channel affine parameters, shaped to broadcast against input.
        weight = self.weight.view(*new_shape)
        bias = self.bias.view(*new_shape)

        if not self.training:
            # Evaluation: every sample is normalized with the *_t running
            # statistics; the *_s statistics are only used for alpha below.
            mean_s = self.running_mean_s.view(*new_shape)
            var_s = self.running_var_s.view(*new_shape)
            mean_t = self.running_mean_t.view(*new_shape)
            var_t = self.running_var_t.view(*new_shape)
            output = (input - mean_t) / (var_t + self.eps).sqrt()
            output = output * weight + bias
        else:
            batch_size = input.shape[0]
            # Fail early with a clear message instead of the opaque
            # "too many values to unpack" that an odd batch would trigger,
            # and before any running buffer has been modified.
            if batch_size < 2 or batch_size % 2 != 0:
                raise ValueError(
                    "training mode needs an even, non-zero batch size to "
                    "split into two halves, got {}".format(batch_size)
                )
            input_s, input_t = torch.split(input, batch_size // 2, dim=0)

            def normalize_half(x, running_mean, running_var):
                # Normalize one half with its own batch statistics and fold
                # those statistics into the given running buffers in place.
                flat = x.transpose(0, 1).reshape(C, -1)
                mean = flat.mean(1)
                var = flat.var(1)  # torch.var default estimator
                # running = momentum * running + (1 - momentum) * batch.
                # NOTE(review): this weights the *old* value by momentum,
                # the opposite of torch.nn.BatchNorm's convention; kept
                # as-is to preserve existing behavior -- confirm intent.
                running_mean.mul_(self.momentum).add_(
                    (1 - self.momentum) * mean.detach()
                )
                running_var.mul_(self.momentum).add_(
                    (1 - self.momentum) * var.detach()
                )
                mean = mean.reshape(*new_shape)
                var = var.reshape(*new_shape)
                normalized = (x - mean) / (var + self.eps).sqrt()
                return normalized * weight + bias, mean, var

            output_s, mean_s, var_s = normalize_half(
                input_s, self.running_mean_s, self.running_var_s
            )
            output_t, mean_t, var_t = normalize_half(
                input_t, self.running_mean_t, self.running_var_t
            )
            output = torch.cat([output_s, output_t], 0)

        if self.adaptive_alpha:
            # Channel-wise rescaling; alpha is treated as a constant so no
            # gradient flows back through the statistics it is built from.
            alpha = self._compute_alpha(mean_s, var_s, mean_t, var_t)
            alpha = alpha.reshape(*new_shape)
            output = (1 + alpha.detach()) * output

        return output