in inplace_abn/abn.py [0:0]
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Batch-normalize ``x`` and then apply the configured activation.

    Normalization uses ``functional.batch_norm`` with the module's affine
    parameters (``self.weight``, ``self.bias``) and ``self.eps``. The momentum,
    training flag and running statistics come from
    ``self._get_momentum_and_training()`` and ``self._get_running_stats()``.

    The activation is chosen by ``self.activation``:
    ``"relu"``, ``"leaky_relu"`` (``self.activation_param`` is the negative
    slope), ``"elu"`` (``self.activation_param`` is alpha) or ``"identity"``.
    Non-identity activations run with ``inplace=True`` on the batch-norm output.

    Args:
        x: Input tensor passed straight to ``functional.batch_norm``.

    Returns:
        The normalized and activated tensor.

    Raises:
        RuntimeError: If ``self.activation`` is not one of the names above.
    """
    # NOTE(review): these helpers live elsewhere in the class; presumably they
    # resolve momentum / training mode and the running-stat buffers — confirm.
    bn_momentum, is_training = self._get_momentum_and_training()
    mean_buffer, var_buffer = self._get_running_stats()

    normalized = functional.batch_norm(
        x,
        mean_buffer,
        var_buffer,
        weight=self.weight,
        bias=self.bias,
        training=is_training,
        momentum=bn_momentum,
        eps=self.eps,
    )

    kind = self.activation
    if kind == "relu":
        return functional.relu(normalized, inplace=True)
    if kind == "leaky_relu":
        return functional.leaky_relu(
            normalized, negative_slope=self.activation_param, inplace=True
        )
    if kind == "elu":
        return functional.elu(normalized, alpha=self.activation_param, inplace=True)
    if kind == "identity":
        return normalized

    # Any other name is a configuration error.
    raise RuntimeError(f"Unknown activation function {kind}")