def vjp_bwd()

in tensorflow_probability/python/math/ode/base.py [0:0]
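
This is the backward pass of the gradient registered for the ODE solvers' `solve` method: it implements the adjoint sensitivity method, integrating an augmented ODE backwards in time through the stored result times to produce gradients of the solution with respect to the initial state, the `constants` passed to `ode_fn`, and any `tf.Variable`s captured by it.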


    def vjp_bwd(results_constants, dresults, variables=()):
      """Adjoint sensitivity method to compute gradients."""
      results, constants = results_constants
      adjoint_solver = self._make_adjoint_solver_fn()
      dstates = dresults.states
      # TODO(b/138304303): Support complex types.
      with tf.name_scope('{}Gradients'.format(self._name)):
        get_dtype = lambda x: x.dtype
        def error_if_complex(dtype):
          if dtype_util.is_complex(dtype):
            raise NotImplementedError('The adjoint sensitivity method does '
                                      'not support complex dtypes.')

        state_dtypes = tf.nest.map_structure(get_dtype, initial_state)
        tf.nest.map_structure(error_if_complex, state_dtypes)
        common_state_dtype = dtype_util.common_dtype(initial_state)
        real_dtype = dtype_util.real_dtype(common_state_dtype)

        # We add initial_time to ensure that we know where to stop.
        result_times = tf.concat(
            [[tf.cast(initial_time, real_dtype)], results.times], 0)
        num_result_times = tf.size(result_times)

        # The first two components correspond to the reverse and adjoint
        # states; the last two components are the adjoint states for the
        # variables and the constants.
        terminal_augmented_state = tuple([
            rk_util.nest_constant(initial_state, 0.0),
            rk_util.nest_constant(initial_state, 0.0),
            tuple(
                rk_util.nest_constant(variable, 0.0) for variable in variables
            ),
            rk_util.nest_constant(constants, 0.0),
        ])

        # The XLA compiler does not compile code which slices/indexes using
        # integer `Tensor`s. `TensorArray`s are used to get around this.
        result_time_array = tf.TensorArray(
            results.times.dtype,
            clear_after_read=False,
            size=num_result_times,
            element_shape=[]).unstack(result_times)

        # The TensorArray shape should not include the time dimension, hence
        # shape[1:].
        result_state_arrays = [
            tf.TensorArray(  # pylint: disable=g-complex-comprehension
                dtype=component.dtype, size=num_result_times - 1,
                clear_after_read=False,
                element_shape=component.shape[1:]).unstack(component)
            for component in tf.nest.flatten(results.states)
        ]
        result_state_arrays = tf.nest.pack_sequence_as(
            results.states, result_state_arrays)
        dresult_state_arrays = [
            tf.TensorArray(  # pylint: disable=g-complex-comprehension
                dtype=component.dtype, size=num_result_times - 1,
                clear_after_read=False,
                element_shape=component.shape[1:]).unstack(component)
            for component in tf.nest.flatten(dstates)
        ]
        dresult_state_arrays = tf.nest.pack_sequence_as(
            results.states, dresult_state_arrays)

        def augmented_ode_fn(backward_time, augmented_state):
          """Dynamics function for the augmented system.

          Describes a differential equation that evolves the augmented state
          backwards in time to compute gradients using the adjoint method.
          Augmented state consists of 4 components `(state, adjoint_state,
          vars, constants)` all evaluated at time `backward_time`:

          state: represents the solution of user provided `ode_fn`. The
            structure coincides with the `initial_state`.
          adjoint_state: represents the solution of the adjoint sensitivity
            differential equation as discussed below. Has the same structure
            and shape as `state`.
          variables: represents the solution of the adjoint equation for
            variable gradients. Represented as a `Tuple(Tensor, ...)` with as
            many tensors as there are elements in the `variables` argument
            outside this function.
          constants: represents the solution of the adjoint equation for
            constant gradients. Has the same structure and shape as the
            `constants` variable outside this function.

          The adjoint sensitivity equation describes the gradient of the
          solution with respect to the value of the solution at a previous
          time t. Its dynamics are given by
          d/dt[adj(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), z)
          This is computed componentwise as:
          d/dt[adj(t)]_i = -1 * sum_j(adj(t)_j * d/dz_i[ode_fn(t, z)_j])
          d/dt[adj(t)]_i = -1 * d/dz_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]
          where in the last line we moved adj(t)_j under the derivative by
          stopping its gradient (hence `no_grad_adj`).

          The adjoint equation for the gradient with respect to every
          `tf.Variable` and constant `theta` is:
          d/dt[grad_theta(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), theta)
          = -1 * d/d theta_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]
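
          As a concrete illustration, for the scalar ODE dz/dt = a * z the
          Jacobian with respect to z is simply a, so the adjoint evolves as
          d/dt[adj(t)] = -1 * a * adj(t), while the gradient with respect to
          the parameter a accumulates d/dt[grad_a(t)] = -1 * adj(t) * z(t).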

          Args:
            backward_time: Floating `Tensor` representing the current time.
            augmented_state: `Tuple(state, adjoint_state, variable_grads,
              constant_grads)` as described above.

          Returns:
            negative_derivatives: Structure of `Tensor`s equal to backwards
              time derivative of the `state` component.
            adjoint_ode: Structure of `Tensor`s equal to backwards time
              derivative of the `adjoint_state` component.
            adjoint_variables_ode: Structure of `Tensor`s equal to backwards
              time derivative of the `vars` component.
            adjoint_constants_ode: Structure of `Tensor`s equal to backwards
              time derivative of the `constants` component.
          """
          # The negative sign disappears after the change of variables.
          # The ODE solver cannot handle the case initial_time > final_time
          # and hence a change of variables backward_time = -time is used.
          time = -backward_time
          state, adjoint_state, _, _ = augmented_state

          # TODO(b/152464477): Doesn't work reliably in TF1.
          def grad_fn(state, variables, constants):
            del variables  # We compute these gradients via the GradientTape
            # capturing them.
            derivatives = ode_fn(time, state, **constants)
            adjoint_no_grad = tf.nest.map_structure(tf.stop_gradient,
                                                    adjoint_state)
            negative_derivatives = rk_util.weighted_sum([-1.0], [derivatives])

            def dot_prod(tensor_a, tensor_b):
              return tf.reduce_sum(tensor_a * tensor_b)

            # See docstring for details.
            adjoint_dot_derivatives = tf.nest.map_structure(
                dot_prod, adjoint_no_grad, derivatives)
            adjoint_dot_derivatives = tf.squeeze(
                tf.add_n(tf.nest.flatten(adjoint_dot_derivatives)))
            return adjoint_dot_derivatives, negative_derivatives

          values = (state, tuple(variables), constants)
          ((_, negative_derivatives),
           gradients) = tfp_gradient.value_and_gradient(
               grad_fn, values, has_aux=True, use_gradient_tape=True)

          (adjoint_ode, adjoint_variables_ode,
           adjoint_constants_ode) = tf.nest.map_structure(
               lambda v, g: tf.zeros_like(v) if g is None else g, values,
               tuple(gradients))
          return (negative_derivatives, adjoint_ode, adjoint_variables_ode,
                  adjoint_constants_ode)

        def make_augmented_state(n, prev_augmented_state):
          """Constructs the augmented state for step `n`."""
          (_, adjoint_state, adjoint_variable_state,
           adjoint_constant_state) = prev_augmented_state
          initial_state = _read_solution_components(
              result_state_arrays,
              input_state_structure,
              n - 1,
          )
          initial_adjoint = _read_solution_components(
              dresult_state_arrays,
              input_state_structure,
              n - 1,
          )
          initial_adjoint_state = rk_util.weighted_sum(
              [1.0, 1.0], [adjoint_state, initial_adjoint])
          augmented_state = (
              initial_state,
              initial_adjoint_state,
              adjoint_variable_state,
              adjoint_constant_state,
          )
          return augmented_state

        def reverse_to_result_time(n, augmented_state, solver_internal_state,
                                   _):
          """Integrates the augmented system backwards in time."""
          lower_bound_of_integration = result_time_array.read(n)
          upper_bound_of_integration = result_time_array.read(n - 1)
          initial_augmented_state = make_augmented_state(n, augmented_state)
          # TODO(b/138304303): Allow the user to specify the Hessian of
          # `ode_fn` so that we can get the Jacobian of the adjoint system.
          # TODO(b/143624114): Support higher order derivatives.
          solver_internal_state = (
              adjoint_solver._adjust_solver_internal_state_for_state_jump(  # pylint: disable=protected-access
                  ode_fn=augmented_ode_fn,
                  initial_time=-lower_bound_of_integration,
                  initial_state=initial_augmented_state,
                  previous_solver_internal_state=solver_internal_state,
                  previous_state=augmented_state,
              ))
          augmented_results = adjoint_solver.solve(
              ode_fn=augmented_ode_fn,
              initial_time=-lower_bound_of_integration,
              initial_state=initial_augmented_state,
              solution_times=[-upper_bound_of_integration],
              batch_ndims=batch_ndims,
              previous_solver_internal_state=solver_internal_state,
          )
          # Results added an extra time dim of size 1, squeeze it.
          select_result = lambda x: tf.squeeze(x, [0])
          result_state = augmented_results.states
          result_state = tf.nest.map_structure(select_result, result_state)
          status = augmented_results.diagnostics.status
          return (n - 1, result_state,
                  augmented_results.solver_internal_state, status)

        initial_n = num_result_times - 1
        solver_internal_state = adjoint_solver._initialize_solver_internal_state(  # pylint: disable=protected-access
            ode_fn=augmented_ode_fn,
            initial_time=result_time_array.read(initial_n),
            initial_state=make_augmented_state(initial_n,
                                               terminal_augmented_state),
        )

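        # Integrate backwards over the result intervals, from the final result
        # time towards `initial_time`, stopping early if the solver reports a
        # non-zero status.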
        _, augmented_state, _, _ = tf.while_loop(
            lambda n, _as, _sis, status: (n >= 1) & tf.equal(status, 0),
            reverse_to_result_time,
            (initial_n, terminal_augmented_state, solver_internal_state, 0),
            back_prop=False,
        )
        (_, adjoint_state, adjoint_variables,
         adjoint_constants) = augmented_state

        if variables:
          return (adjoint_state, adjoint_constants), list(adjoint_variables)
        else:
          return adjoint_state, adjoint_constants
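
A minimal sketch (not part of base.py) of how this backward pass is reached from user code, assuming one of the standard `tfp.math.ode` solvers such as `DormandPrince`; the `ode_fn`, state, and loss below are illustrative only. Differentiating through `solve` is what ultimately invokes `vjp_bwd`, so the gradient is computed with the adjoint method above rather than by backpropagating through the solver's internal steps.

    import tensorflow as tf
    import tensorflow_probability as tfp

    def ode_fn(t, y):
      # Simple linear dynamics dy/dt = -y; any user-provided ode_fn works.
      return -y

    y0 = tf.constant([1.0, 2.0])
    with tf.GradientTape() as tape:
      tape.watch(y0)
      results = tfp.math.ode.DormandPrince().solve(
          ode_fn,
          initial_time=0.0,
          initial_state=y0,
          solution_times=[0.5, 1.0])
      loss = tf.reduce_sum(results.states)
    # This gradient is produced by the adjoint pass implemented in vjp_bwd.
    grad_y0 = tape.gradient(loss, y0)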