pyro/distributions/transforms/utils.py (2 lines of code) (raw):
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0
def clamp_preserve_gradients(x, min, max):
# This helper function clamps gradients but still passes through the gradient in clamped regions
return x + (x.clamp(min, max) - x).detach()