# discussion/turnkey_inference_candidate/window_tune_nuts_sampling.py
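# Module-level imports assumed by this excerpt (a sketch; the full file's
# import block is not shown here and its exact aliases may differ):
from absl import logging
import tensorflow.compat.v2 as tf
import tensorflow_probability as tfp
from tensorflow_probability.python.internal import prefer_static as ps

tfb = tfp.bijectors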
def _sample_posterior(target_log_prob_unconstrained,
prior_samples_unconstrained,
init_state=None,
num_samples=500,
nchains=4,
init_nchains=1,
target_accept_prob=.8,
max_tree_depth=9,
use_scaled_init=True,
tuning_window_schedule=(75, 25, 25, 25, 25, 25, 50),
use_wide_window_expanding_mode=True,
seed=None,
parallel_iterations=10,
jit_compile=True,
use_input_signature=False,
experimental_relax_shapes=False):
"""MCMC sampling with HMC/NUTS using an expanding epoch tuning scheme."""
seed_stream = tfp.util.SeedStream(seed, 'window_tune_nuts_sampling')
rv_rank = ps.rank(prior_samples_unconstrained)
assert rv_rank == 2
total_ndims = ps.shape(prior_samples_unconstrained)[-1]
dtype = prior_samples_unconstrained.dtype
  # TODO(b/158878248): explore option for the user to control the
  # parameterization of conditioning_bijector.
  # TODO(b/158878248): right now, we use two tf.Variables to initialize a
  # scaling bijector, and update the underlying values at the end of each
  # warmup window. It might be faster to rewrite it in a functional style
  # (with a small additional compilation cost).
loc_conditioner = tf.Variable(
tf.zeros([total_ndims], dtype=dtype), name='loc_conditioner')
scale_conditioner = tf.Variable(
tf.ones([total_ndims], dtype=dtype), name='scale_conditioner')
  # Start with an identity covariance matrix.
scale = tf.linalg.LinearOperatorDiag(
diag=scale_conditioner,
is_non_singular=True,
is_self_adjoint=True,
is_positive_definite=True)
conditioning_bijector = tfb.Shift(shift=loc_conditioner)(
tfb.ScaleMatvecLinearOperator(scale))
if init_state is None:
    # If no initial state is given, start from uniform random draws in
    # [-1, 1] in latent space (rescaled around the prior mean and standard
    # deviation below when use_scaled_init is True).
init_state_uniform = tf.random.uniform(
[init_nchains, total_ndims], dtype=dtype, seed=seed_stream()) * 2. - 1.
if use_scaled_init:
prior_z_mean = tf.math.reduce_mean(prior_samples_unconstrained, axis=0)
prior_z_std = tf.math.reduce_std(prior_samples_unconstrained, axis=0)
init_state = init_state_uniform * prior_z_std + prior_z_mean
else:
init_state = init_state_uniform
  # The denominator is the O(N^0.25) scaling from Beskos et al. 2010. The
  # numerator corresponds to the trajectory length. Candidate values include
  # 1 and 1.57 (pi / 2). We use a conservatively small value here (0.25).
init_step_size = tf.constant(0.25 / (total_ndims**0.25), dtype=dtype)
hmc_inner = tfp.mcmc.TransformedTransitionKernel(
tfp.mcmc.NoUTurnSampler(
target_log_prob_fn=target_log_prob_unconstrained,
step_size=init_step_size,
max_tree_depth=max_tree_depth,
parallel_iterations=parallel_iterations,
), conditioning_bijector)
hmc_step_size_tuning = tfp.mcmc.DualAveragingStepSizeAdaptation(
inner_kernel=hmc_inner,
num_adaptation_steps=max(tuning_window_schedule),
target_accept_prob=target_accept_prob)
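  # Both fast_adaptation_interval and slow_adaptation_interval below call
  # bootstrap_results at the start of each window, which resets the dual
  # averaging state; using the longest window length as num_adaptation_steps
  # keeps step-size adaptation active for the entirety of every window.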
if use_input_signature:
input_signature = [
tf.TensorSpec(shape=None, dtype=tf.int32),
tf.TensorSpec(shape=[None, total_ndims], dtype=dtype),
]
else:
input_signature = None
# TODO(b/158878248): move the nested function definitions to module top-level.
@tf.function(
input_signature=input_signature,
autograph=False,
jit_compile=jit_compile,
experimental_relax_shapes=experimental_relax_shapes)
def fast_adaptation_interval(num_steps, previous_state):
"""Step size only adaptation interval.
This corresponds to window 1 and window 3 in the Stan HMC parameter
tuning scheme.
Args:
num_steps: Number of tuning steps the interval will run.
previous_state: Initial state of the tuning interval.
Returns:
last_state: Last state of the tuning interval.
last_pkr: Kernel result from the TransitionKernel at the end of the
tuning interval.
"""
def body_fn(i, state, pkr):
next_state, next_pkr = hmc_step_size_tuning.one_step(state, pkr)
return i + 1, next_state, next_pkr
current_pkr = hmc_step_size_tuning.bootstrap_results(previous_state)
_, last_state, last_pkr = tf.while_loop(
lambda i, *_: i < num_steps,
body_fn,
loop_vars=(0, previous_state, current_pkr),
maximum_iterations=num_steps,
parallel_iterations=parallel_iterations)
return last_state, last_pkr
def body_fn_window2(
i, previous_state, previous_pkr, previous_mean, previous_cov):
"""Take one MCMC step and update the step size and mass matrix."""
next_state, next_pkr = hmc_step_size_tuning.one_step(
previous_state, previous_pkr)
n_next = i + 1
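    # Streaming (Welford-style) update of the per-chain running mean and a
    # covariance accumulator; the accumulator is scaled by 1 / (n - 1) and
    # regularized at the end of each window in slow_adaptation_interval below.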
delta_pre = previous_state - previous_mean
next_mean = previous_mean + delta_pre / tf.cast(n_next, delta_pre.dtype)
delta_post = previous_state - next_mean
delta_cov = tf.expand_dims(delta_post, -1) * tf.expand_dims(delta_pre, -2)
next_cov = previous_cov + delta_cov
next_mean.set_shape(previous_mean.shape)
next_cov.set_shape(previous_cov.shape)
return n_next, next_state, next_pkr, next_mean, next_cov
if use_input_signature:
input_signature = [
tf.TensorSpec(shape=None, dtype=tf.int32),
tf.TensorSpec(shape=None, dtype=tf.int32),
tf.TensorSpec(shape=[None, total_ndims], dtype=dtype),
tf.TensorSpec(shape=[None, total_ndims], dtype=dtype),
tf.TensorSpec(shape=[None, total_ndims, total_ndims], dtype=dtype),
]
else:
input_signature = None
# TODO(b/158878248): move the nested function definitions to module top-level.
@tf.function(
input_signature=input_signature,
autograph=False,
jit_compile=jit_compile,
experimental_relax_shapes=experimental_relax_shapes)
def slow_adaptation_interval(num_steps, previous_n, previous_state,
previous_mean, previous_cov):
"""Interval that tunes the mass matrix and step size simultaneously.
This corresponds to window 2 in the Stan HMC parameter tuning scheme.
Args:
num_steps: Number of tuning steps the interval will run.
previous_n: Previous number of tuning steps we have run.
previous_state: Initial state of the tuning interval.
previous_mean: Current estimated posterior mean.
previous_cov: Current estimated posterior covariance matrix.
Returns:
total_n: Total number of tuning steps we have run.
next_state: Last state of the tuning interval.
next_pkr: Kernel result from the TransitionKernel at the end of the
tuning interval.
next_mean: estimated posterior mean after tuning.
next_cov: estimated posterior covariance matrix after tuning.
"""
previous_pkr = hmc_step_size_tuning.bootstrap_results(previous_state)
total_n, next_state, next_pkr, next_mean, next_cov = tf.while_loop(
lambda i, *_: i < num_steps + previous_n,
body_fn_window2,
loop_vars=(previous_n, previous_state, previous_pkr, previous_mean,
previous_cov),
maximum_iterations=num_steps,
parallel_iterations=parallel_iterations)
float_n = tf.cast(total_n, next_cov.dtype)
cov = next_cov / (float_n - 1.)
# Regularization
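    # Stan-style shrinkage of the sample covariance toward the identity:
    #   cov_reg = (n / (n + 5)) * cov + 1e-3 * (5 / (n + 5))
    # (the scalar shrinkage term is broadcast to every entry here, but only
    # the diagonal of the result is consumed downstream).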
scaled_cov = (float_n / (float_n + 5.)) * cov
shrinkage = 1e-3 * (5. / (float_n + 5.))
next_cov = scaled_cov + shrinkage
return total_n, next_state, next_pkr, next_mean, next_cov
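  # Per-step NUTS diagnostics traced during the final (post-warmup) sampling
  # run; returned to the caller as the `diagnostic` output.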
def trace_fn(_, pkr):
return (
pkr.inner_results.target_log_prob,
pkr.inner_results.leapfrogs_taken,
pkr.inner_results.has_divergence,
pkr.inner_results.energy,
pkr.inner_results.log_accept_ratio,
pkr.inner_results.reach_max_depth,
pkr.inner_results.step_size,
)
@tf.function(autograph=False, jit_compile=jit_compile)
def run_chain(num_results, current_state, previous_kernel_results):
return tfp.mcmc.sample_chain(
num_results=num_results,
num_burnin_steps=0,
current_state=current_state,
previous_kernel_results=previous_kernel_results,
kernel=hmc_inner,
trace_fn=trace_fn,
parallel_iterations=parallel_iterations,
seed=seed_stream())
  # Main tuning-and-sampling routine.
num_steps_tuning_window_schedule0 = tuning_window_schedule[0]
# Window 1 to tune step size
logging.info('Tuning Window 1...')
next_state, _ = fast_adaptation_interval(num_steps_tuning_window_schedule0,
init_state)
next_mean = tf.zeros_like(init_state)
next_cov = tf.zeros(
ps.concat([ps.shape(init_state), ps.shape(init_state)[-1:]], axis=-1),
dtype=dtype)
mean_updater = tf.zeros([total_ndims], dtype=dtype)
diag_updater = tf.ones([total_ndims], dtype=dtype)
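  # Running estimates of the posterior mean and (diagonal) standard deviation;
  # these are assigned into loc_conditioner / scale_conditioner before each
  # subsequent tuning window and before the final sampling run.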
# Window 2 to tune mass matrix.
total_n = 0
for i, num_steps in enumerate(tuning_window_schedule[1:-1]):
logging.info('Tuning Window 2 - %s...', i)
if not use_wide_window_expanding_mode:
num_steps = num_steps * 2**i
with tf.control_dependencies([
loc_conditioner.assign(mean_updater, read_value=False),
scale_conditioner.assign(diag_updater, read_value=False)
]):
(total_n, next_state_, _, next_mean_,
next_cov_) = slow_adaptation_interval(num_steps, total_n, next_state,
next_mean, next_cov)
diag_part = tf.linalg.diag_part(next_cov_)
if ps.rank(next_state) > 1:
mean_updater = tf.reduce_mean(next_mean_, axis=0)
diag_updater = tf.math.sqrt(tf.reduce_mean(diag_part, axis=0))
else:
mean_updater = next_mean_
diag_updater = tf.math.sqrt(diag_part)
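    # In wide-window expanding mode, double the number of parallel chains each
    # window (by tiling the current states and statistics) instead of doubling
    # the window length as in the Stan scheme.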
if use_wide_window_expanding_mode:
next_mean = tf.concat([next_mean_, next_mean_], axis=0)
next_cov = tf.concat([next_cov_, next_cov_], axis=0)
next_state = tf.concat([next_state_, next_state_], axis=0)
else:
next_mean, next_cov, next_state = next_mean_, next_cov_, next_state_
num_steps_tuning_window_schedule3 = tuning_window_schedule[-1]
num_batches = ps.size0(next_state)
if nchains > num_batches:
    # Replicate the tuned states (ceiling division) so that at least `nchains`
    # initial states are available, then truncate to exactly `nchains`.
    final_init_state = tf.repeat(
        next_state,
        (nchains + num_batches - 1) // num_batches,
        axis=0)[:nchains]
else:
final_init_state = next_state[:nchains]
with tf.control_dependencies([
loc_conditioner.assign(mean_updater, read_value=False),
scale_conditioner.assign(diag_updater, read_value=False)
]):
# Window 3 step size tuning
logging.info('Tuning Window 3...')
final_tuned_state, final_pkr = fast_adaptation_interval(
num_steps_tuning_window_schedule3, final_init_state)
# Final samples
logging.info('Sampling...')
nuts_samples, diagnostic = run_chain(num_samples, final_tuned_state,
final_pkr.inner_results)
return nuts_samples, diagnostic, conditioning_bijector
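# Example usage (a sketch, not part of the original module; `target` and
# `prior_samples` below are hypothetical names introduced for illustration):
#
#   target = tfp.distributions.MultivariateNormalDiag(
#       loc=tf.zeros([4]), scale_diag=tf.constant([.5, 1., 2., 4.]))
#   prior_samples = target.sample(1000, seed=1)  # shape [1000, 4]
#   samples, diagnostics, bijector = _sample_posterior(
#       target.log_prob, prior_samples, num_samples=200, nchains=4, seed=42)
#   # `samples` has shape [num_samples, nchains, 4] in unconstrained space.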