in tensorflow_probability/python/math/gradient.py [0:0]
def _value_and_grad_impl(f, grad_fn, *args, output_gradients,
                         auto_unpack_single_arg,
                         expand_tf_modules_as_trainable_vars=False,
                         has_aux=False,
                         **kwargs):
  """Shared machinery behind the gradient and Jacobian entry points.

  Args:
    f: Callable being differentiated. It is invoked with the (prepared)
      positional and keyword arguments when `_has_args(f)` is truthy, and with
      no arguments otherwise. When `has_aux` is true it must return a pair
      `(value, aux)`.
    grad_fn: Callable invoked as
      `grad_fn(thunk, flat_sources, output_gradients)`, where `thunk` takes no
      arguments and returns a `(value, aux)` pair. It must return a triple
      `(y, dydx, aux)` whose `dydx` is flat and matches `flat_sources`.
    *args: Positional values to differentiate with respect to.
    output_gradients: Forwarded unchanged to `grad_fn`.
    auto_unpack_single_arg: Legacy switch. If truthy and the only input is a
      single positional `tuple`/`list` (no kwargs), that sequence is treated
      as the positional arguments themselves.
    expand_tf_modules_as_trainable_vars: If true, every argument for which
      `tensor_util.is_module` is true is replaced by its
      `trainable_variables` when building the differentiation sources.
    has_aux: Whether `f` already returns a `(value, aux)` pair.
    **kwargs: Keyword values to differentiate with respect to.

  Returns:
    A tuple. Its first element is `y`, or `(y, aux)` when `has_aux` is true.
    It is followed by the gradients w.r.t. the positional arguments (if any
    remain after preparation) and then by the gradients w.r.t. the keyword
    arguments (if any). A single non-unpacked positional argument yields its
    gradient directly rather than wrapped in a one-element sequence.

  Raises:
    ValueError: If neither positional nor keyword arguments are given.
  """
  if not (args or kwargs):
    raise ValueError('Gradient is not defined unless at least one of `arg` or '
                     '`kwarg` is specified.')

  # Backwards compatibility: with one positional arg and no kwargs we cannot
  # otherwise tell which calling protocol the caller intended, so a lone
  # tuple/list is treated as the full positional argument list when
  # `auto_unpack_single_arg` is set (the historical behavior).
  unpacked = (auto_unpack_single_arg and
              len(args) == 1 and
              not kwargs and
              isinstance(args[0], (tuple, list)))
  if unpacked:
    args = args[0]
  # NOTE(review): `_prepare_args` is defined elsewhere; presumably it
  # normalizes the values (e.g. tensor conversion) -- confirm.
  args, kwargs = _prepare_args(args, kwargs)

  # The differentiation sources mirror args/kwargs, except that modules may
  # be swapped for their trainable variables.
  if expand_tf_modules_as_trainable_vars:
    def _module_to_vars(leaf):
      if tensor_util.is_module(leaf):
        return leaf.trainable_variables
      return leaf
    src_args, src_kwargs = tf.nest.map_structure(
        _module_to_vars, [args, kwargs])
  else:
    src_args, src_kwargs = args, kwargs

  # Give `grad_fn` a uniform contract: the target always yields a
  # `(value, aux)` pair, using an empty aux when the caller supplied none.
  if has_aux:
    target = f
  else:
    def target(*call_args, **call_kwargs):
      value = f(*call_args, **call_kwargs) if _has_args(f) else f()
      return value, ()

  def _thunk():
    if _has_args(target):
      return target(*args, **kwargs)
    return target()

  y, flat_grads, aux = grad_fn(
      _thunk,
      tf.nest.flatten([src_args, src_kwargs]),
      output_gradients)

  grads_args, grads_kwargs = tf.nest.pack_sequence_as(
      [src_args, src_kwargs], flat_grads)
  # A single positional argument (not produced by legacy unpacking) gets its
  # gradient returned bare, matching the shape of what the caller passed.
  if not unpacked and len(args) == 1:
    grads_args = grads_args[0]

  outputs = [(y, aux) if has_aux else y]
  if args:
    outputs.append(grads_args)
  if kwargs:
    outputs.append(grads_kwargs)
  return tuple(outputs)