def _sample_posterior()

in discussion/turnkey_inference_candidate/window_tune_nuts_sampling.py [0:0]


def _sample_posterior(target_log_prob_unconstrained,
                      prior_samples_unconstrained,
                      init_state=None,
                      num_samples=500,
                      nchains=4,
                      init_nchains=1,
                      target_accept_prob=.8,
                      max_tree_depth=9,
                      use_scaled_init=True,
                      tuning_window_schedule=(75, 25, 25, 25, 25, 25, 50),
                      use_wide_window_expanding_mode=True,
                      seed=None,
                      parallel_iterations=10,
                      jit_compile=True,
                      use_input_signature=False,
                      experimental_relax_shapes=False):
  """MCMC sampling with HMC/NUTS using an expanding epoch tuning scheme.

  Tuning follows the windowed layout of Stan's HMC adaptation:

    * window 1 (`tuning_window_schedule[0]` steps) adapts only the step size;
    * window 2 (one interval per entry of `tuning_window_schedule[1:-1]`)
      adapts the step size while accumulating a running mean/covariance
      estimate, which updates a diagonal affine conditioning bijector after
      each interval;
    * window 3 (`tuning_window_schedule[-1]` steps) adapts the step size once
      more with the final conditioner.

  After tuning, `num_samples` NUTS draws are taken for `nchains` chains.

  Args:
    target_log_prob_unconstrained: Callable used as the `target_log_prob_fn`
      of the NUTS kernel.
    prior_samples_unconstrained: Rank-2 `Tensor` of prior draws in the
      unconstrained space. Its last dimension gives the event size and its
      dtype is used throughout. With `use_scaled_init`, its per-dimension
      mean/std place the random initial state.
    init_state: Optional initial state. If `None`, `init_nchains` states are
      drawn uniformly from [-1, 1] (see `use_scaled_init`).
    num_samples: Number of draws per chain after tuning.
    nchains: Number of chains used in window 3 and for the final draws.
    init_nchains: Number of chains drawn when `init_state` is `None`.
    target_accept_prob: Target acceptance probability for the dual-averaging
      step-size adaptation.
    max_tree_depth: Maximum NUTS tree depth.
    use_scaled_init: If `True`, the uniform initial state is scaled by the
      per-dimension std and shifted by the per-dimension mean of
      `prior_samples_unconstrained`.
    tuning_window_schedule: Sequence of step counts. The first entry is
      window 1, the last entry is window 3, and the entries in between are
      window-2 intervals.
    use_wide_window_expanding_mode: If `True`, the chain batch is duplicated
      (doubled) after every window-2 interval. If `False`, the i-th window-2
      interval instead runs `tuning_window_schedule[1 + i] * 2**i` steps.
    seed: Seed for `tfp.util.SeedStream`.
    parallel_iterations: Forwarded to NUTS, `tf.while_loop` and
      `tfp.mcmc.sample_chain`.
    jit_compile: Forwarded to every `tf.function`.
    use_input_signature: If `True`, the tuning `tf.function`s are traced with
      an explicit input signature with a dynamic leading (chain) dimension.
    experimental_relax_shapes: Forwarded to the tuning `tf.function`s.

  Returns:
    nuts_samples: Chain states returned by `tfp.mcmc.sample_chain`, in the
      space of `target_log_prob_unconstrained`'s argument.
    diagnostic: Tuple traced at every draw: (target_log_prob,
      leapfrogs_taken, has_divergence, energy, log_accept_ratio,
      reach_max_depth, step_size).
    conditioning_bijector: The tuned `Shift(loc)(ScaleMatvecLinearOperator)`
      bijector, backed by `tf.Variable`s.
  """

  seed_stream = tfp.util.SeedStream(seed, 'window_tune_nuts_sampling')
  rv_rank = ps.rank(prior_samples_unconstrained)
  assert rv_rank == 2, (
      'prior_samples_unconstrained must be rank 2, got rank %s' % rv_rank)
  total_ndims = ps.shape(prior_samples_unconstrained)[-1]
  dtype = prior_samples_unconstrained.dtype

  # TODO(b/158878248): explore an option for the user to control the
  # parameterization of conditioning_bijector.
  # TODO(b/158878248): right now, we use 2 tf.Variable to initialize a scaling
  # bijector, and update the underlying value at the end of each warmup window.
  # It might be faster to rewrite it into a functional style (with a small
  # additional compilation cost).
  loc_conditioner = tf.Variable(
      tf.zeros([total_ndims], dtype=dtype), name='loc_conditioner')
  scale_conditioner = tf.Variable(
      tf.ones([total_ndims], dtype=dtype), name='scale_conditioner')

  # Start with Identity Covariance matrix
  scale = tf.linalg.LinearOperatorDiag(
      diag=scale_conditioner,
      is_non_singular=True,
      is_self_adjoint=True,
      is_positive_definite=True)
  conditioning_bijector = tfb.Shift(shift=loc_conditioner)(
      tfb.ScaleMatvecLinearOperator(scale))

  if init_state is None:
    # Draw uniform noise in [-1, 1]. Optionally rescale it by the
    # per-dimension prior std and centre it at the prior mean.
    init_state_uniform = tf.random.uniform(
        [init_nchains, total_ndims], dtype=dtype, seed=seed_stream()) * 2. - 1.
    if use_scaled_init:
      prior_z_mean = tf.math.reduce_mean(prior_samples_unconstrained, axis=0)
      prior_z_std = tf.math.reduce_std(prior_samples_unconstrained, axis=0)
      init_state = init_state_uniform * prior_z_std + prior_z_mean
    else:
      init_state = init_state_uniform

  # The denominator is the O(N^0.25) scaling from Beskos et al. 2010. The
  # numerator corresponds to the trajectory length. Candidate values include
  # 1 and 1.57 (pi / 2). We use a conservatively small value here (0.25).
  init_step_size = tf.constant(0.25 / (total_ndims**0.25), dtype=dtype)

  hmc_inner = tfp.mcmc.TransformedTransitionKernel(
      tfp.mcmc.NoUTurnSampler(
          target_log_prob_fn=target_log_prob_unconstrained,
          step_size=init_step_size,
          max_tree_depth=max_tree_depth,
          parallel_iterations=parallel_iterations,
      ), conditioning_bijector)

  # NOTE(review): every tuning interval calls `bootstrap_results`, so this
  # limit presumably applies per interval. Intervals longer than
  # `max(tuning_window_schedule)` (possible when
  # `use_wide_window_expanding_mode=False`) would stop adapting the step size
  # before they end. Confirm that this is intended.
  hmc_step_size_tuning = tfp.mcmc.DualAveragingStepSizeAdaptation(
      inner_kernel=hmc_inner,
      num_adaptation_steps=max(tuning_window_schedule),
      target_accept_prob=target_accept_prob)

  if use_input_signature:
    input_signature = [
        tf.TensorSpec(shape=None, dtype=tf.int32),
        tf.TensorSpec(shape=[None, total_ndims], dtype=dtype),
    ]
  else:
    input_signature = None

  # TODO(b/158878248): move the nested function definitions to module top-level.
  @tf.function(
      input_signature=input_signature,
      autograph=False,
      jit_compile=jit_compile,
      experimental_relax_shapes=experimental_relax_shapes)
  def fast_adaptation_interval(num_steps, previous_state):
    """Step size only adaptation interval.

    This corresponds to window 1 and window 3 in the Stan HMC parameter
    tuning scheme.

    Args:
      num_steps: Number of tuning steps the interval will run.
      previous_state: Initial state of the tuning interval.

    Returns:
      last_state: Last state of the tuning interval.
      last_pkr: Kernel result from the TransitionKernel at the end of the
        tuning interval.
    """

    def body_fn(i, state, pkr):
      next_state, next_pkr = hmc_step_size_tuning.one_step(state, pkr)
      return i + 1, next_state, next_pkr

    current_pkr = hmc_step_size_tuning.bootstrap_results(previous_state)
    _, last_state, last_pkr = tf.while_loop(
        lambda i, *_: i < num_steps,
        body_fn,
        loop_vars=(0, previous_state, current_pkr),
        maximum_iterations=num_steps,
        parallel_iterations=parallel_iterations)
    return last_state, last_pkr

  def body_fn_window2(
      i, previous_state, previous_pkr, previous_mean, previous_m2):
    """Take one MCMC step and fold the pre-step state into the moments.

    This is Welford's online update. `previous_m2` is the running sum of
    outer products of deviations from the mean (not a normalized
    covariance).
    """
    next_state, next_pkr = hmc_step_size_tuning.one_step(
        previous_state, previous_pkr)
    n_next = i + 1
    delta_pre = previous_state - previous_mean
    next_mean = previous_mean + delta_pre / tf.cast(n_next, delta_pre.dtype)
    delta_post = previous_state - next_mean
    delta_m2 = tf.expand_dims(delta_post, -1) * tf.expand_dims(delta_pre, -2)
    next_m2 = previous_m2 + delta_m2

    next_mean.set_shape(previous_mean.shape)
    next_m2.set_shape(previous_m2.shape)
    return n_next, next_state, next_pkr, next_mean, next_m2

  if use_input_signature:
    input_signature = [
        tf.TensorSpec(shape=None, dtype=tf.int32),
        tf.TensorSpec(shape=None, dtype=tf.int32),
        tf.TensorSpec(shape=[None, total_ndims], dtype=dtype),
        tf.TensorSpec(shape=[None, total_ndims], dtype=dtype),
        tf.TensorSpec(shape=[None, total_ndims, total_ndims], dtype=dtype),
    ]
  else:
    input_signature = None

  # TODO(b/158878248): move the nested function definitions to module top-level.
  @tf.function(
      input_signature=input_signature,
      autograph=False,
      jit_compile=jit_compile,
      experimental_relax_shapes=experimental_relax_shapes)
  def slow_adaptation_interval(num_steps, previous_n, previous_state,
                               previous_mean, previous_m2):
    """Interval that tunes the mass matrix and step size simultaneously.

    This corresponds to window 2 in the Stan HMC parameter tuning scheme.

    Args:
      num_steps: Number of tuning steps the interval will run.
      previous_n: Previous number of tuning steps we have run.
      previous_state: Initial state of the tuning interval.
      previous_mean: Running mean over the previous `previous_n` steps.
      previous_m2: Running sum of outer products of deviations (Welford's
        M2) over the previous `previous_n` steps.

    Returns:
      total_n: Total number of tuning steps we have run.
      next_state: Last state of the tuning interval.
      next_pkr: Kernel result from the TransitionKernel at the end of the
        tuning interval.
      next_mean: Running mean after this interval.
      next_m2: Raw M2 accumulator after this interval. Feed it back as
        `previous_m2` for the next interval.
      next_cov: Regularized posterior covariance estimate derived from
        `next_m2`.
    """
    previous_pkr = hmc_step_size_tuning.bootstrap_results(previous_state)
    total_n, next_state, next_pkr, next_mean, next_m2 = tf.while_loop(
        lambda i, *_: i < num_steps + previous_n,
        body_fn_window2,
        loop_vars=(previous_n, previous_state, previous_pkr, previous_mean,
                   previous_m2),
        maximum_iterations=num_steps,
        parallel_iterations=parallel_iterations)
    float_n = tf.cast(total_n, next_m2.dtype)
    # Guard the n == 1 case, which would otherwise produce 0 / 0 = NaN.
    cov = next_m2 / tf.maximum(float_n - 1., 1.)

    # Regularization mirroring Stan's diagonal-metric update. The scalar
    # shrinkage is added to every entry, but only the diagonal is used below.
    scaled_cov = (float_n / (float_n + 5.)) * cov
    shrinkage = 1e-3 * (5. / (float_n + 5.))
    next_cov = scaled_cov + shrinkage

    # The raw accumulator is returned separately. The next interval then keeps
    # summing squared deviations instead of restarting from a normalized,
    # regularized covariance that would be divided by `total_n - 1` again.
    return total_n, next_state, next_pkr, next_mean, next_m2, next_cov

  def trace_fn(_, pkr):
    return (
        pkr.inner_results.target_log_prob,
        pkr.inner_results.leapfrogs_taken,
        pkr.inner_results.has_divergence,
        pkr.inner_results.energy,
        pkr.inner_results.log_accept_ratio,
        pkr.inner_results.reach_max_depth,
        pkr.inner_results.step_size,
    )

  @tf.function(autograph=False, jit_compile=jit_compile)
  def run_chain(num_results, current_state, previous_kernel_results):
    return tfp.mcmc.sample_chain(
        num_results=num_results,
        num_burnin_steps=0,
        current_state=current_state,
        previous_kernel_results=previous_kernel_results,
        kernel=hmc_inner,
        trace_fn=trace_fn,
        parallel_iterations=parallel_iterations,
        seed=seed_stream())

  # Main sampling with tuning routine.
  num_steps_tuning_window_schedule0 = tuning_window_schedule[0]

  # Window 1 to tune step size
  logging.info('Tuning Window 1...')
  next_state, _ = fast_adaptation_interval(num_steps_tuning_window_schedule0,
                                           init_state)

  next_mean = tf.zeros_like(init_state)
  next_m2 = tf.zeros(
      ps.concat([ps.shape(init_state), ps.shape(init_state)[-1:]], axis=-1),
      dtype=dtype)

  mean_updater = tf.zeros([total_ndims], dtype=dtype)
  diag_updater = tf.ones([total_ndims], dtype=dtype)

  # Window 2 to tune mass matrix.
  total_n = 0
  for i, num_steps in enumerate(tuning_window_schedule[1:-1]):
    logging.info('Tuning Window 2 - %s...', i)
    if not use_wide_window_expanding_mode:
      num_steps = num_steps * 2**i
    with tf.control_dependencies([
        loc_conditioner.assign(mean_updater, read_value=False),
        scale_conditioner.assign(diag_updater, read_value=False)
    ]):
      (total_n, next_state_, _, next_mean_, next_m2_,
       next_cov_) = slow_adaptation_interval(num_steps, total_n, next_state,
                                             next_mean, next_m2)
      diag_part = tf.linalg.diag_part(next_cov_)
      if ps.rank(next_state) > 1:
        mean_updater = tf.reduce_mean(next_mean_, axis=0)
        diag_updater = tf.math.sqrt(tf.reduce_mean(diag_part, axis=0))
      else:
        mean_updater = next_mean_
        diag_updater = tf.math.sqrt(diag_part)

      if use_wide_window_expanding_mode:
        next_mean = tf.concat([next_mean_, next_mean_], axis=0)
        next_m2 = tf.concat([next_m2_, next_m2_], axis=0)
        next_state = tf.concat([next_state_, next_state_], axis=0)
      else:
        next_mean, next_m2, next_state = next_mean_, next_m2_, next_state_

  num_steps_tuning_window_schedule3 = tuning_window_schedule[-1]
  num_batches = ps.size0(next_state)
  if nchains > num_batches:
    # Ceiling division: the repeated batch must have at least `nchains` rows,
    # otherwise the slice below silently yields fewer chains.
    num_repeats = (nchains + num_batches - 1) // num_batches
    final_init_state = tf.repeat(next_state, num_repeats, axis=0)[:nchains]
  else:
    final_init_state = next_state[:nchains]

  with tf.control_dependencies([
      loc_conditioner.assign(mean_updater, read_value=False),
      scale_conditioner.assign(diag_updater, read_value=False)
  ]):
    # Window 3 step size tuning
    logging.info('Tuning Window 3...')
    final_tuned_state, final_pkr = fast_adaptation_interval(
        num_steps_tuning_window_schedule3, final_init_state)

    # Final samples
    logging.info('Sampling...')
    nuts_samples, diagnostic = run_chain(num_samples, final_tuned_state,
                                         final_pkr.inner_results)

  return nuts_samples, diagnostic, conditioning_bijector