def _value_and_grad_impl()

in tensorflow_probability/python/math/gradient.py


# Module-level imports from gradient.py that this helper relies on; the
# private helpers `_prepare_args` and `_has_args` are defined in the same file.
import tensorflow.compat.v2 as tf

from tensorflow_probability.python.internal import tensor_util


def _value_and_grad_impl(f, grad_fn, *args, output_gradients,
                         auto_unpack_single_arg,
                         expand_tf_modules_as_trainable_vars=False,
                         has_aux=False,
                         **kwargs):
  """Helper which generalizes gradient / Jacobian."""
  if not args and not kwargs:
    raise ValueError('Gradient is not defined unless at least one of `args` '
                     'or `kwargs` is specified.')
  # The following is for backwards compatibility. In the single-arg case with
  # no kwargs we cannot tell which calling protocol is intended without
  # `auto_unpack_single_arg`. When it is `True` and the sole arg is a tuple or
  # list, we unpack it as if it were the args, i.e., preserve the old behavior
  # (see the usage sketch after this function).
  do_unpack = (auto_unpack_single_arg and len(args) == 1 and not kwargs and
               isinstance(args[0], (tuple, list)))
  if do_unpack:
    args = args[0]
  args, kwargs = _prepare_args(args, kwargs)
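  # Optionally treat any `tf.Module` (e.g. a Keras layer) among the inputs as
  # the collection of its trainable variables, so gradients are taken with
  # respect to those variables rather than the module object itself.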
  if expand_tf_modules_as_trainable_vars:
    expand_args, expand_kwargs = tf.nest.map_structure(
        lambda x: x.trainable_variables if tensor_util.is_module(x) else x,
        [args, kwargs])
  else:
    expand_args, expand_kwargs = args, kwargs

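  # When `f` has no auxiliary output, wrap it so that it returns an empty aux
  # value; the code below can then treat both cases uniformly as `(output, aux)`.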
  if not has_aux:
    real_f = f
    f = lambda *args, **kwargs: (real_f(*args, **kwargs)  # pylint: disable=g-long-lambda
                                 if _has_args(real_f) else real_f(), ())

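  # `grad_fn` is handed a no-argument callable that evaluates `f`, the flat
  # list of tensors/variables to differentiate with respect to, and any
  # `output_gradients`; it returns the value, the flat gradients, and the
  # auxiliary output.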
  y, dydx, aux = grad_fn(lambda: f(*args, **kwargs) if _has_args(f) else f(),
                         tf.nest.flatten([expand_args, expand_kwargs]),
                         output_gradients)
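  # Re-pack the flat gradients into the same (args, kwargs) structure that was
  # flattened above.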
  dydx_args, dydx_kwargs = tf.nest.pack_sequence_as(
      [expand_args, expand_kwargs], dydx)
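  # In the single-argument (new-protocol) case, return the gradient itself
  # rather than a length-1 structure.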
  if len(args) == 1 and not do_unpack:
    dydx_args = dydx_args[0]

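  # Return the value (paired with aux if requested), followed by gradients for
  # whichever of `args` / `kwargs` were actually supplied.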
  if has_aux:
    res = ((y, aux),)
  else:
    res = (y,)

  if args:
    res += (dydx_args,)
  if kwargs:
    res += (dydx_kwargs,)

  return res
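

Example usage (a minimal sketch; it assumes this helper is reached through the
public `tfp.math.value_and_gradient` wrapper with its default
`auto_unpack_single_arg=True`; the numeric results are illustrative):

import tensorflow_probability as tfp

# New-style protocol: each positional argument is differentiated separately.
y, dydx = tfp.math.value_and_gradient(lambda x, y: x * y, 1., 2.)
# y == 2.,  dydx == (2., 1.)

# Legacy protocol: a single tuple/list argument is auto-unpacked as if it were
# the args, so the gradients come back with one entry per element of the list.
y, dydx = tfp.math.value_and_gradient(lambda x, y: x * y, [1., 2.])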