def maybe_step()

in tensorflow_probability/python/math/ode/bdf.py [0:0]


    def maybe_step(accepted, diagnostics, iterand, solver_internal_state):
      """Takes a single step only if the outcome has a low enough error.

      Attempts one BDF step at the current `order` and `step_size`. It
      (re)factorizes the Newton iteration matrix when needed, runs Newton's
      method, and then decides whether to accept the step. It also decides how
      to adjust the step size, the order and the Jacobian for the next attempt.

      Args:
        accepted: Ignored on input; it is recomputed and returned.
        diagnostics: Sequence unpacked as `(num_jacobian_evaluations,
          num_matrix_factorizations, num_ode_fn_evaluations, status)`.
        iterand: Sequence unpacked as `(jacobian_mat, jacobian_is_up_to_date,
          new_step_size, num_steps, num_steps_same_size,
          should_update_jacobian, should_update_step_size, time, unitary,
          upper)`.
        solver_internal_state: Sequence unpacked as `(backward_differences,
          order, step_size)`.

      Returns:
        A tuple `(accepted, diagnostics, iterand, solver_internal_state)`.
        `accepted` is a boolean tensor that is true iff Newton's method
        converged and the error ratio is below 1. The remaining elements are a
        `_BDFDiagnostics`, a `_BDFIterand` and a `_BDFSolverInternalState`
        holding the updated values.
      """
      [
          num_jacobian_evaluations, num_matrix_factorizations,
          num_ode_fn_evaluations, status
      ] = diagnostics
      [
          jacobian_mat, jacobian_is_up_to_date, new_step_size, num_steps,
          num_steps_same_size, should_update_jacobian, should_update_step_size,
          time, unitary, upper
      ] = iterand
      [backward_differences, order, step_size] = solver_internal_state

      if max_num_steps is not None:
        status = tf1.where(tf.equal(num_steps, max_num_steps), -1, 0)

      # Apply a step size change requested by the previous attempt: rescale the
      # backward differences to the new step size and reset the counter of
      # consecutive steps taken with the same size.
      backward_differences = tf1.where(
          should_update_step_size,
          bdf_util.interpolate_backward_differences(backward_differences, order,
                                                    new_step_size / step_size),
          backward_differences)
      step_size = tf1.where(should_update_step_size, new_step_size, step_size)
      # The factorization below is computed from `step_size`, so it has to be
      # redone whenever the step size changes.
      should_update_factorization = should_update_step_size
      num_steps_same_size = tf1.where(should_update_step_size, 0,
                                      num_steps_same_size)

      def update_factorization(jacobian_to_factorize):
        # The Jacobian is passed explicitly instead of being read from the
        # enclosing `jacobian_mat`. A freshly evaluated Jacobian is only bound
        # to `jacobian_mat` after the `tf.cond` below returns, so a closure
        # would factorize the stale Jacobian.
        return bdf_util.newton_qr(jacobian_to_factorize,
                                  newton_coefficients_array.read(order),
                                  step_size)

      if self._evaluate_jacobian_lazily:

        def update_jacobian_and_factorization():
          new_jacobian_mat = jacobian_fn_mat(time, backward_differences[0])
          new_unitary, new_upper = update_factorization(new_jacobian_mat)
          return [
              new_jacobian_mat, True, num_jacobian_evaluations + 1,
              num_matrix_factorizations + 1, new_unitary, new_upper
          ]

        def maybe_update_factorization():

          def refactorize():
            new_unitary, new_upper = update_factorization(jacobian_mat)
            return [num_matrix_factorizations + 1, new_unitary, new_upper]

          new_num_matrix_factorizations, new_unitary, new_upper = tf.cond(
              should_update_factorization, refactorize,
              lambda: [num_matrix_factorizations, unitary, upper])
          return [
              jacobian_mat, jacobian_is_up_to_date, num_jacobian_evaluations,
              new_num_matrix_factorizations, new_unitary, new_upper
          ]

        # Both branches count every factorization they compute, so the
        # `num_matrix_factorizations` diagnostic is also tracked in lazy mode.
        [
            jacobian_mat, jacobian_is_up_to_date, num_jacobian_evaluations,
            num_matrix_factorizations, unitary, upper
        ] = tf.cond(should_update_jacobian, update_jacobian_and_factorization,
                    maybe_update_factorization)
      else:
        unitary, upper = update_factorization(jacobian_mat)
        num_matrix_factorizations += 1

      # Error tolerance built from `atol`/`rtol` and the magnitude of
      # `backward_differences[0]`; the Newton tolerance is a scaled norm of it.
      tol = p.atol + p.rtol * tf.abs(backward_differences[0])
      newton_tol = newton_tol_factor * tf.norm(tol)

      [
          newton_converged, next_backward_difference, next_state_vec,
          newton_num_iters
      ] = bdf_util.newton(backward_differences, max_num_newton_iters,
                          newton_coefficients_array.read(order), p.ode_fn_vec,
                          order, step_size, time, newton_tol, unitary, upper)
      num_steps += 1
      num_ode_fn_evaluations += newton_num_iters

      # If Newton's method failed and the Jacobian was up to date, decrease the
      # step size.
      newton_failed = tf.logical_not(newton_converged)
      should_update_step_size = newton_failed & jacobian_is_up_to_date
      new_step_size = step_size * tf1.where(should_update_step_size,
                                            newton_step_size_factor, 1.)

      # If Newton's method failed and the Jacobian was NOT up to date, update
      # the Jacobian.
      should_update_jacobian = newton_failed & tf.logical_not(
          jacobian_is_up_to_date)

      # The error ratio is NaN when Newton's method did not converge, so the
      # comparison below yields `accepted == False` in that case.
      error_ratio = tf1.where(
          newton_converged,
          bdf_util.error_ratio(next_backward_difference,
                               error_coefficients_array.read(order), tol),
          np.nan)
      accepted = error_ratio < 1.
      converged_and_rejected = newton_converged & tf.logical_not(accepted)

      # If Newton's method converged but the solution was NOT accepted, decrease
      # the step size.
      new_step_size = tf1.where(
          converged_and_rejected,
          util.next_step_size(step_size, order, error_ratio, p.safety_factor,
                              min_step_size_factor, max_step_size_factor),
          new_step_size)
      should_update_step_size = should_update_step_size | converged_and_rejected

      # If Newton's method converged and the solution was accepted, update the
      # matrix of backward differences.
      time = tf1.where(accepted, time + step_size, time)
      backward_differences = tf1.where(
          accepted,
          bdf_util.update_backward_differences(backward_differences,
                                               next_backward_difference,
                                               next_state_vec, order),
          backward_differences)
      jacobian_is_up_to_date = jacobian_is_up_to_date & tf.logical_not(accepted)
      num_steps_same_size = tf1.where(accepted, num_steps_same_size + 1,
                                      num_steps_same_size)

      # Order and step size are only updated if we have taken strictly more than
      # order + 1 steps of the same size. This is to prevent the order from
      # being throttled.
      should_update_order_and_step_size = accepted & (
          num_steps_same_size > order + 1)

      # Indexable view of the backward differences so rows can be read at the
      # tensor-valued proposed order below.
      backward_differences_array = tf.TensorArray(
          backward_differences.dtype,
          size=bdf_util.MAX_ORDER + 3,
          clear_after_read=False,
          element_shape=next_backward_difference.shape).unstack(
              backward_differences)
      # Try the neighbouring orders (clipped to [1, max_order]) and keep the
      # one with the lowest error ratio.
      new_order = order
      new_error_ratio = error_ratio
      for offset in [-1, +1]:
        proposed_order = tf.clip_by_value(order + offset, 1, max_order)
        proposed_error_ratio = bdf_util.error_ratio(
            backward_differences_array.read(proposed_order + 1),
            error_coefficients_array.read(proposed_order), tol)
        proposed_error_ratio_is_lower = proposed_error_ratio < new_error_ratio
        new_order = tf1.where(
            should_update_order_and_step_size & proposed_error_ratio_is_lower,
            proposed_order, new_order)
        new_error_ratio = tf1.where(
            should_update_order_and_step_size & proposed_error_ratio_is_lower,
            proposed_error_ratio, new_error_ratio)
      order = new_order
      error_ratio = new_error_ratio

      new_step_size = tf1.where(
          should_update_order_and_step_size,
          util.next_step_size(step_size, order, error_ratio, p.safety_factor,
                              min_step_size_factor, max_step_size_factor),
          new_step_size)
      should_update_step_size = (
          should_update_step_size | should_update_order_and_step_size)

      diagnostics = _BDFDiagnostics(num_jacobian_evaluations,
                                    num_matrix_factorizations,
                                    num_ode_fn_evaluations, status)
      iterand = _BDFIterand(jacobian_mat, jacobian_is_up_to_date, new_step_size,
                            num_steps, num_steps_same_size,
                            should_update_jacobian, should_update_step_size,
                            time, unitary, upper)
      solver_internal_state = _BDFSolverInternalState(backward_differences,
                                                      order, step_size)
      return accepted, diagnostics, iterand, solver_internal_state