def backward()

in inplace_abn/functions.py [0:0]


    def backward(ctx, dy_act):
        """Compute gradients for the in-place activated batch norm.

        Takes the gradient w.r.t. the post-activation output ``dy_act`` and
        produces gradients for the first three forward inputs (x, weight,
        bias); the remaining nine forward inputs get ``None``.

        Note: ``_backend.backward_train`` overwrites ``dy`` in place, turning
        it into ``dx`` — this is the "in-place" part of the algorithm.
        """
        y_act, var, count, weight, bias = ctx.saved_tensors

        # Guard clause: skip all work when no input needs a gradient.
        if not any(ctx.needs_input_grad):
            return (None,) * 12

        # Invert the activation and recover xhat / dy together with the
        # per-process partial sums required by the batch-norm gradient.
        xhat, dy, sum_dy_local, sum_xhat_dy_local = _backend.backward_reduce(
            y_act,
            dy_act,
            weight,
            bias,
            ctx.eps,
            ctx.activation,
            ctx.activation_param,
        )

        # In the distributed setting the statistics gradients must be
        # aggregated across all processes in the group.
        sum_dy, sum_xhat_dy = sum_dy_local, sum_xhat_dy_local
        if ctx.world_size > 1:
            sum_dy, sum_xhat_dy = InPlaceABN._reduce_backward(
                sum_dy_local, sum_xhat_dy_local, ctx.group
            )

        # Gradient w.r.t. x
        dx = None
        if ctx.needs_input_grad[0]:
            if ctx.training:
                # backward_train clobbers dy with dx in place.
                _backend.backward_train(
                    xhat, dy, var, count, sum_dy, sum_xhat_dy, weight, ctx.eps
                )
                dx = dy
            else:
                dx = _backend.backward_test(dy, var, weight, ctx.eps)

        # Gradient w.r.t. weight: the local sum with a sign correction where
        # the saved weight is negative (the kernels presumably operate on
        # |weight| — NOTE(review): confirm against the forward pass).
        dweight = None
        if weight is not None and ctx.needs_input_grad[1]:
            dweight = sum_xhat_dy_local
            dweight[weight < 0] *= -1

        # Gradient w.r.t. bias: simply the local sum of dy.
        dbias = sum_dy_local if bias is not None and ctx.needs_input_grad[2] else None

        return (dx, dweight, dbias) + (None,) * 9