def diag_jacobian()

in tensorflow_probability/python/math/diag_jacobian.py [0:0]


def diag_jacobian(xs,
                  ys=None,
                  sample_shape=None,
                  fn=None,
                  parallel_iterations=10,
                  name=None):
  """Computes diagonal of the Jacobian matrix of `ys=fn(xs)` wrt `xs`.

  If `ys` is a tensor or a list of tensors of the form `(ys_1, .., ys_n)` and
  `xs` is of the form `(xs_1, .., xs_n)`, the function `diag_jacobian`
  computes the diagonal of the Jacobian matrix, i.e., the partial derivatives
  `(dys_1/dxs_1, .., dys_n/dxs_n)`. For definition details, see
  https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant

  #### Example

  ##### Diagonal Hessian of the log-density of a 3D Gaussian distribution

  In this example we compute the diagonal of the Hessian of the log-density
  of a 3D Gaussian distribution by applying `diag_jacobian` to the gradient
  of that log-density.

  ```python
  import tensorflow as tf
  import tensorflow_probability as tfp
  import numpy as np

  tfd = tfp.distributions

  dtype = np.float32
  with tf.Session(graph=tf.Graph()) as sess:
    true_mean = dtype([0, 0, 0])
    true_cov = dtype([[1, 0.25, 0.25], [0.25, 2, 0.25], [0.25, 0.25, 3]])
    chol = tf.linalg.cholesky(true_cov)
    target = tfd.MultivariateNormalTriL(loc=true_mean, scale_tril=chol)

    # Assume that the state is passed as a list of tensors `x` and `y`.
    # Then the target function is defined as follows:
    def target_fn(x, y):
      # Stack the input tensors together
      z = tf.concat([x, y], axis=-1) - true_mean
      return target.log_prob(z)

    sample_shape = [3, 5]
    state = [tf.ones(sample_shape + [2], dtype=dtype),
             tf.ones(sample_shape + [1], dtype=dtype)]
    fn_val, grads = tfp.math.value_and_gradient(target_fn, state)

    # We can either pass the `sample_shape` of the `state` or not, which impacts
    # computational speed of `diag_jacobian`
    _, diag_jacobian_shape_passed = tfp.math.diag_jacobian(
        xs=state, ys=grads, sample_shape=tf.shape(fn_val))
    _, diag_jacobian_shape_none = tfp.math.diag_jacobian(
        xs=state, ys=grads)

    diag_jacobian_shape_passed_ = sess.run(diag_jacobian_shape_passed)
    diag_jacobian_shape_none_ = sess.run(diag_jacobian_shape_none)

  print('hessian computed through `diag_jacobian`, sample_shape passed: ',
        np.concatenate(diag_jacobian_shape_passed_, -1))
  print('hessian computed through `diag_jacobian`, sample_shape skipped',
        np.concatenate(diag_jacobian_shape_none_, -1))

  ```

  Args:
    xs: `Tensor` or a python `list` of `Tensors` of real-like dtypes and shapes
      `sample_shape` + `event_shape_i`, where `event_shape_i` can be different
      for different tensors.
    ys: `Tensor` or a python `list` of `Tensors` of the same dtype as `xs`. Must
      broadcast with the shape of `xs`. Can be omitted if `fn` is provided.
    sample_shape: A common `sample_shape` of the input tensors of `xs`. If not
      provided, assumed to be `[1]`, which may result in a slow performance of
      `diag_jacobian`.
    fn: Python callable that takes `xs` as an argument (or `*xs`, if it is a
      list) and returns `ys`. Might be skipped if `ys` is provided and
      `tf.enable_eager_execution()` is disabled.
    parallel_iterations: `int` that specifies the allowed number of coordinates
      of the input tensor `xs`, for which the partial derivatives `dys_i/dxs_i`
      can be computed in parallel.
    name: Python `str` name prefixed to `Ops` created by this function.
      Default value: `None` (i.e., "jacobians_diag").

  Returns:
    ys: In graph mode, a list which coincides with the input `ys`, when
      provided; if the input `ys` is None, `fn(*xs)` gets computed and
      returned as a list. In eager mode, the input `ys` (or the result of
      `fn(*xs)` when `ys` is None) is returned without being converted to a
      list.
    jacobians_diag_res: a Python list of `Tensor`s of the same dtypes and
      shapes as the (list-converted) input `xs`. This is the diagonal of the
      Jacobian of ys wrt xs.

  Raises:
    ValueError: if lists `xs` and `ys` have different length (checked in graph
      mode only), or both `ys` and `fn` are `None`, or `fn` is None in the
      eager execution mode.
  """
  with tf.name_scope(name or 'jacobians_diag'):
    if sample_shape is None:
      sample_shape = [1]
    # Output Jacobian diagonal
    jacobians_diag_res = []
    # Convert input `xs` to a list
    xs = list(xs) if _is_list_like(xs) else [xs]
    xs = [tf.convert_to_tensor(x) for x in xs]
    if not tf.executing_eagerly():
      if ys is None:
        if fn is None:
          raise ValueError('Both `ys` and `fn` can not be `None`')
        ys = fn(*xs)
      # Convert ys to a list
      ys = list(ys) if _is_list_like(ys) else [ys]
      if len(xs) != len(ys):
        raise ValueError('`xs` and `ys` should have the same length')
      for y, x in zip(ys, xs):
        # Broadcast `y` to the shape of `x`.
        y_ = y + tf.zeros_like(x)
        # Change `event_shape` to one-dimension. The broadcast tensor `y_`
        # (not the raw `y`) must be reshaped, otherwise the broadcast above is
        # discarded and `y_[..., j]` below can index past the event size.
        y_ = tf.reshape(y_, tf.concat([sample_shape, [-1]], -1))

        # Declare an iterator and tensor array loop variables for the gradients.
        # `n` is the flattened event size: size(x) / prod(sample_shape).
        n = tf.size(x) / tf.cast(tf.reduce_prod(sample_shape), dtype=tf.int32)
        n = tf.cast(n, dtype=tf.int32)
        loop_vars = [
            0,
            tf.TensorArray(x.dtype, n)
        ]

        def loop_body(j):
          """Loop function to compute gradients of the each direction."""
          # Gradient along direction `j`.
          res = tf.gradients(ys=y_[..., j], xs=x)[0]  # pylint: disable=cell-var-from-loop
          if res is None:
            # Return zero, if the gradient is `None`.
            res = tf.zeros(tf.concat([sample_shape, [1]], -1),
                           dtype=x.dtype)  # pylint: disable=cell-var-from-loop
          else:
            # Reshape `event_shape` to 1D
            res = tf.reshape(res, tf.concat([sample_shape, [-1]], -1))
            # Add artificial dimension for the case of zero shape input tensor
            res = res[tf.newaxis, ..., j]
          return res  # pylint: disable=cell-var-from-loop

        # Iterate over all elements of the gradient and compute second order
        # derivatives.
        _, jacobian_diag_res = tf.while_loop(
            cond=lambda j, _: j < n,  # pylint: disable=cell-var-from-loop
            body=lambda j, result: (j + 1, result.write(j, loop_body(j))),
            loop_vars=loop_vars,
            parallel_iterations=parallel_iterations)

        shape_x = ps.shape(x)
        # Stack gradients together and move flattened `event_shape` to the
        # zero position
        reshaped_jacobian_diag = tf.transpose(a=jacobian_diag_res.stack())
        # Reshape to the original tensor
        reshaped_jacobian_diag = tf.reshape(reshaped_jacobian_diag, shape_x)
        jacobians_diag_res.append(reshaped_jacobian_diag)

    else:
      # In eager mode gradients are obtained by re-evaluating `fn`, so `fn` is
      # required even when `ys` is supplied.
      if fn is None:
        raise ValueError('`fn` can not be `None` when eager execution is '
                         'enabled')
      if ys is None:
        ys = fn(*xs)

      def fn_slice(i, j):
        """Broadcast y[i], flatten event shape of y[i], return y[i][..., j]."""
        def fn_broadcast(*state):
          res = fn(*state)
          res = list(res) if _is_list_like(res) else [res]
          if len(res) != len(state):
            # Repeat the outputs so that there is one entry per input tensor.
            res *= len(state)
          res = [tf.reshape(r + tf.zeros_like(s),
                            ps.concat([sample_shape, [-1]], -1))
                 for r, s in zip(res, state)]
          return res
        # Expand dimensions before returning in order to support 0D input `xs`
        return lambda *state: tf.expand_dims(fn_broadcast(*state)[i], 0)[..., j]

      def make_loop_body(i, x):
        """Loop function to compute gradients of the each direction."""
        def _fn(j, result):
          # Gradient of the `j`-th flattened coordinate of `ys[i]` wrt `xs[i]`.
          res = value_and_gradient(fn_slice(i, j), xs)[1][i]
          if res is None:
            # Return zero, if the gradient is `None`.
            res = tf.zeros(sample_shape, dtype=x.dtype)
          else:
            # Flatten `event_shape` and keep only the diagonal entry `j`.
            res = tf.reshape(res, ps.concat([sample_shape, [-1]], -1))
            res = res[..., j]
          return j + 1, result.write(j, res)
        return _fn

      for i, x in enumerate(xs):
        # Declare an iterator and tensor array loop variables for the gradients.
        # `n` is the flattened event size: size(x) / prod(sample_shape).
        n = ps.size(x) / ps.cast(ps.reduce_prod(sample_shape), dtype=tf.int32)
        n = ps.cast(n, dtype=tf.int32)
        loop_vars = (0, tf.TensorArray(x.dtype, n, element_shape=sample_shape))

        # Iterate over all elements of the gradient and compute second order
        # derivatives.
        _, jacobian_diag_res = tf.while_loop(
            cond=lambda j, _: j < n,
            body=make_loop_body(i, x),
            loop_vars=loop_vars,
            parallel_iterations=parallel_iterations)

        shape_x = ps.shape(x)
        # Stack gradients together and move flattened `event_shape` to the
        # zero position
        reshaped_jacobian_diag = tf.transpose(jacobian_diag_res.stack())
        # Reshape to the original tensor
        reshaped_jacobian_diag = tf.reshape(reshaped_jacobian_diag, shape_x)
        jacobians_diag_res.append(reshaped_jacobian_diag)

  return ys, jacobians_diag_res