in inplace_abn/functions.py [0:0]
def backward(ctx, dy_act):
    """Backward pass of in-place activated batch norm.

    Receives the gradient w.r.t. the activated output ``dy_act`` and returns a
    12-tuple ``(dx, dweight, dbias) + (None,) * 9`` — the trailing ``None``s
    presumably correspond to the non-differentiable arguments of ``forward``
    (eps, activation settings, distributed config, ...); confirm against the
    forward signature elsewhere in this file.

    Saved tensors (from forward): y_act (activated output), var, count,
    weight, bias.
    """
    y_act, var, count, weight, bias = ctx.saved_tensors
    # Call backward_reduce if we need to compute at least one of the gradients
    if any(ctx.needs_input_grad):
        # backward_reduce inverts the activation and produces:
        #   xhat              - normalized input reconstructed from y_act
        #   dy                - gradient w.r.t. the pre-activation output
        #   sum_dy_local      - per-process reduction of dy
        #   sum_xhat_dy_local - per-process reduction of xhat * dy
        xhat, dy, sum_dy_local, sum_xhat_dy_local = _backend.backward_reduce(
            y_act,
            dy_act,
            weight,
            bias,
            ctx.eps,
            ctx.activation,
            ctx.activation_param,
        )
        if ctx.world_size > 1:
            # Distributed case: all-reduce the partial sums across the group
            # so every process computes dx from global statistics.
            sum_dy, sum_xhat_dy = InPlaceABN._reduce_backward(
                sum_dy_local, sum_xhat_dy_local, ctx.group
            )
        else:
            sum_dy, sum_xhat_dy = sum_dy_local, sum_xhat_dy_local
    else:
        # No gradient requested for any input: skip all work.
        return (None,) * 12
    # Gradient w.r.t. x
    if ctx.needs_input_grad[0]:
        if ctx.training:
            # This overwrites dy with dx
            _backend.backward_train(
                xhat, dy, var, count, sum_dy, sum_xhat_dy, weight, ctx.eps
            )
            dx = dy
        else:
            # Eval mode: batch statistics were fixed in forward, so dx is a
            # simple per-channel rescaling of dy.
            dx = _backend.backward_test(dy, var, weight, ctx.eps)
    else:
        dx = None
    # Gradient w.r.t. weight
    if weight is not None and ctx.needs_input_grad[1]:
        # NOTE: dweight aliases sum_xhat_dy_local; the in-place *= below
        # mutates that tensor. The sign flip where weight < 0 suggests the
        # forward pass used |weight| (chain rule through abs) — TODO confirm
        # against the forward implementation.
        dweight = sum_xhat_dy_local
        dweight[weight < 0] *= -1
    else:
        dweight = None
    # Gradient w.r.t. bias
    if bias is not None and ctx.needs_input_grad[2]:
        dbias = sum_dy_local
    else:
        dbias = None
    # NOTE(review): dweight/dbias use the *local* (un-reduced) sums even when
    # world_size > 1 — presumably the cross-process reduction of parameter
    # grads is left to the DDP wrapper; verify against the training setup.
    return (dx, dweight, dbias) + (None,) * 9