in tensorflow_probability/python/math/ode/base.py [0:0]
def vjp_bwd(results_constants, dresults, variables=()):
"""Adjoint sensitivity method to compute gradients."""
results, constants = results_constants
adjoint_solver = self._make_adjoint_solver_fn()
dstates = dresults.states
# TODO(b/138304303): Support complex types.
with tf.name_scope('{}Gradients'.format(self._name)):
get_dtype = lambda x: x.dtype
def error_if_complex(dtype):
if dtype_util.is_complex(dtype):
raise NotImplementedError('The adjoint sensitivity method does '
'not support complex dtypes.')
state_dtypes = tf.nest.map_structure(get_dtype, initial_state)
tf.nest.map_structure(error_if_complex, state_dtypes)
common_state_dtype = dtype_util.common_dtype(initial_state)
real_dtype = dtype_util.real_dtype(common_state_dtype)
# We add initial_time to ensure that we know where to stop.
result_times = tf.concat(
[[tf.cast(initial_time, real_dtype)], results.times], 0)
num_result_times = tf.size(result_times)
# The first two components correspond to the reverse and adjoint states.
# The last two components are the adjoint states for variables and constants.
terminal_augmented_state = tuple([
rk_util.nest_constant(initial_state, 0.0),
rk_util.nest_constant(initial_state, 0.0),
tuple(
rk_util.nest_constant(variable, 0.0) for variable in variables
),
rk_util.nest_constant(constants, 0.0),
])
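# The terminal augmented state is all zeros: the incoming cotangents from
# `dresults` are folded in one result interval at a time by
# `make_augmented_state` below.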
# The XLA compiler does not compile code which slices/indexes using
# integer `Tensor`s. `TensorArray`s are used to get around this.
result_time_array = tf.TensorArray(
results.times.dtype,
clear_after_read=False,
size=num_result_times,
element_shape=[]).unstack(result_times)
# TensorArray shape should not include time dimension, hence shape[1:]
result_state_arrays = [
tf.TensorArray( # pylint: disable=g-complex-comprehension
dtype=component.dtype, size=num_result_times - 1,
clear_after_read=False,
element_shape=component.shape[1:]).unstack(component)
for component in tf.nest.flatten(results.states)
]
result_state_arrays = tf.nest.pack_sequence_as(
results.states, result_state_arrays)
dresult_state_arrays = [
tf.TensorArray( # pylint: disable=g-complex-comprehension
dtype=component.dtype, size=num_result_times - 1,
clear_after_read=False,
element_shape=component.shape[1:]).unstack(component)
for component in tf.nest.flatten(dstates)
]
dresult_state_arrays = tf.nest.pack_sequence_as(
results.states, dresult_state_arrays)
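# Both sets of arrays are repacked into the user's state structure so that
# `_read_solution_components` can read the solution and its cotangent at a
# given result index.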
def augmented_ode_fn(backward_time, augmented_state):
"""Dynamics function for the augmented system.
Describes a differential equation that evolves the augmented state
backwards in time to compute gradients using the adjoint method.
Augmented state consists of 4 components `(state, adjoint_state,
vars, constants)` all evaluated at time `backward_time`:
state: represents the solution of user provided `ode_fn`. The
structure coincides with the `initial_state`.
adjoint_state: represents the solution of the adjoint sensitivity
differential equation as discussed below. Has the same structure
and shape as `state`.
variables: represents the solution of the adjoint equation for
variable gradients. Represented as a `Tuple(Tensor, ...)` with one
tensor per variable in `variables` outside this function.
constants: represents the solution of the adjoint equation for
constant gradients. Has the same structure and shape as the
`constants` outside this function.
The adjoint sensitivity equation describes the gradient of the
solution with respect to the value of the solution at a previous
time t. Its dynamics are given by
d/dt[adj(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), z)
Which is computed as:
d/dt[adj(t)]_i = -1 * sum_j(adj(t)_j * d/dz_i[ode_fn(t, z)_j])
d/dt[adj(t)]_i = -1 * d/dz_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]
where in the last line we moved adj(t)_j under the derivative by
stopping its gradient (treating it as the constant `no_grad_adj`).
Adjoint equation for the gradient with respect to every
`tf.Variable` and constant theta follows:
d/dt[grad_theta(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), theta)
= -1 * d/d theta_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]
Args:
backward_time: Floating `Tensor` representing current time.
augmented_state: `Tuple(state, adjoint_state, variable_grads,
constant_grads)`
Returns:
negative_derivatives: Structure of `Tensor`s equal to backwards
time derivative of the `state` component.
adjoint_ode: Structure of `Tensor`s equal to backwards time
derivative of the `adjoint_state` component.
adjoint_variables_ode: Structure of `Tensor`s equal to backwards
time derivative of the `vars` component.
adjoint_constants_ode: Structure of `Tensor`s equal to backwards
time derivative of the `constants` component.
"""
# The negative sign disappears after the change of variables.
# The ODE solver cannot handle the case initial_time > final_time
# and hence a change of variables backward_time = -time is used.
time = -backward_time
state, adjoint_state, _, _ = augmented_state
# TODO(b/152464477): Doesn't work reliably in TF1.
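# `grad_fn` reduces the adjoint-Jacobian products from the docstring to a
# single scalar, sum_j(stop_gradient(adj)_j * ode_fn(t, z)_j), so that one
# gradient evaluation yields all vector-Jacobian products without ever
# materializing the Jacobian.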
def grad_fn(state, variables, constants):
del variables # We compute these gradients via the GradientTape
# capturing them.
derivatives = ode_fn(time, state, **constants)
adjoint_no_grad = tf.nest.map_structure(tf.stop_gradient,
adjoint_state)
negative_derivatives = rk_util.weighted_sum([-1.0], [derivatives])
def dot_prod(tensor_a, tensor_b):
return tf.reduce_sum(tensor_a * tensor_b)
# See docstring for details.
adjoint_dot_derivatives = tf.nest.map_structure(
dot_prod, adjoint_no_grad, derivatives)
adjoint_dot_derivatives = tf.squeeze(
tf.add_n(tf.nest.flatten(adjoint_dot_derivatives)))
return adjoint_dot_derivatives, negative_derivatives
values = (state, tuple(variables), constants)
((_, negative_derivatives),
gradients) = tfp_gradient.value_and_gradient(
grad_fn, values, has_aux=True, use_gradient_tape=True)
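# Gradients that come back as `None` (e.g. for variables or constants the
# dynamics do not depend on) are replaced with zeros so the augmented state
# keeps a fixed structure.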
(adjoint_ode, adjoint_variables_ode,
adjoint_constants_ode) = tf.nest.map_structure(
lambda v, g: tf.zeros_like(v) if g is None else g, values,
tuple(gradients))
return (negative_derivatives, adjoint_ode, adjoint_variables_ode,
adjoint_constants_ode)
def make_augmented_state(n, prev_augmented_state):
"""Constructs the augmented state for step `n`."""
(_, adjoint_state, adjoint_variable_state,
adjoint_constant_state) = prev_augmented_state
initial_state = _read_solution_components(
result_state_arrays,
input_state_structure,
n - 1,
)
initial_adjoint = _read_solution_components(
dresult_state_arrays,
input_state_structure,
n - 1,
)
initial_adjoint_state = rk_util.weighted_sum(
[1.0, 1.0], [adjoint_state, initial_adjoint])
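# Adding the cotangent of the result at index `n - 1` to the running
# adjoint implements the jump in the adjoint state at each requested
# solution time.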
augmented_state = (
initial_state,
initial_adjoint_state,
adjoint_variable_state,
adjoint_constant_state,
)
return augmented_state
def reverse_to_result_time(n, augmented_state, solver_internal_state,
_):
"""Integrates the augmented system backwards in time."""
lower_bound_of_integration = result_time_array.read(n)
upper_bound_of_integration = result_time_array.read(n - 1)
initial_augmented_state = make_augmented_state(n, augmented_state)
# TODO(b/138304303): Allow the user to specify the Hessian of
# `ode_fn` so that we can get the Jacobian of the adjoint system.
# TODO(b/143624114): Support higher order derivatives.
solver_internal_state = (
adjoint_solver._adjust_solver_internal_state_for_state_jump( # pylint: disable=protected-access
ode_fn=augmented_ode_fn,
initial_time=-lower_bound_of_integration,
initial_state=initial_augmented_state,
previous_solver_internal_state=solver_internal_state,
previous_state=augmented_state,
))
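# The adjoint state jumps discontinuously at each result time (see
# `make_augmented_state`), so the cached solver state is adjusted before
# restarting the integration.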
augmented_results = adjoint_solver.solve(
ode_fn=augmented_ode_fn,
initial_time=-lower_bound_of_integration,
initial_state=initial_augmented_state,
solution_times=[-upper_bound_of_integration],
batch_ndims=batch_ndims,
previous_solver_internal_state=solver_internal_state,
)
# Results have an extra leading time dimension of size 1; squeeze it.
select_result = lambda x: tf.squeeze(x, [0])
result_state = augmented_results.states
result_state = tf.nest.map_structure(select_result, result_state)
status = augmented_results.diagnostics.status
return (n - 1, result_state,
augmented_results.solver_internal_state, status)
initial_n = num_result_times - 1
solver_internal_state = adjoint_solver._initialize_solver_internal_state( # pylint: disable=protected-access
ode_fn=augmented_ode_fn,
initial_time=result_time_array.read(initial_n),
initial_state=make_augmented_state(initial_n,
terminal_augmented_state),
)
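# Integrate backwards one result interval at a time, stopping once the
# initial time is reached (n == 0) or the solver reports a non-zero status.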
_, augmented_state, _, _ = tf.while_loop(
lambda n, _as, _sis, status: (n >= 1) & tf.equal(status, 0),
reverse_to_result_time,
(initial_n, terminal_augmented_state, solver_internal_state, 0),
back_prop=False,
)
(_, adjoint_state, adjoint_variables,
adjoint_constants) = augmented_state
if variables:
return (adjoint_state, adjoint_constants), list(adjoint_variables)
else:
return adjoint_state, adjoint_constants
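# A minimal usage sketch (not part of base.py), assuming the public
# `tfp.math.ode` API: gradients of `solver.solve` with respect to the initial
# state are produced by the adjoint method implemented in `vjp_bwd` above.
# For dy/dt = -y with y(0) = y0, the solution is y(t) = y0 * exp(-t), so the
# gradient of y(1) with respect to y0 should be approximately exp(-1).
#
#   import tensorflow as tf
#   import tensorflow_probability as tfp
#
#   solver = tfp.math.ode.DormandPrince()
#   initial_state = tf.constant(1.0)
#   with tf.GradientTape() as tape:
#     tape.watch(initial_state)
#     results = solver.solve(
#         ode_fn=lambda t, y: -y,
#         initial_time=0.,
#         initial_state=initial_state,
#         solution_times=[1.])
#     final_state = results.states[-1]
#   grad = tape.gradient(final_state, initial_state)  # ~= exp(-1) ~= 0.368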