def sample_sequential_monte_carlo()

in tensorflow_probability/python/experimental/mcmc/sample_sequential_monte_carlo.py [0:0]


def sample_sequential_monte_carlo(
    prior_log_prob_fn,
    likelihood_log_prob_fn,
    current_state,
    min_num_steps=2,
    max_num_steps=25,
    max_stage=100,
    make_kernel_fn=make_rwmh_kernel_fn,
    tuning_fn=simple_heuristic_tuning,
    make_tempered_target_log_prob_fn=default_make_tempered_target_log_prob_fn,
    resample_fn=weighted_resampling.resample_systematic,
    ess_threshold_ratio=0.5,
    parallel_iterations=10,
    seed=None,
    name=None):
  """Runs Sequential Monte Carlo to sample from the posterior distribution.

  This function uses an MCMC transition operator (e.g., Hamiltonian Monte Carlo)
  to sample from a series of distributions that slowly interpolates between
  an initial 'prior' distribution:

    `exp(prior_log_prob_fn(x))`

  and the target 'posterior' distribution:

    `exp(prior_log_prob_fn(x) + likelihood_log_prob_fn(x))`,

  by mutating a collection of MC samples (i.e., particles). The approach is
  also known as particle filtering in some literature. The current
  implementation is largely based on Del Moral et al. [1], which adapts the
  tempering sequence (based on the effective sample size) and the scaling of
  the mutation kernel (based on the sample covariance of the particles) at
  each stage.
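
  At inverse temperature `beta` in `[0, 1]`, the tempered target density is a
  geometric interpolation between prior and posterior (a sketch of what the
  default `make_tempered_target_log_prob_fn` constructs, shown for a single
  state part `x`):

  ```python
  def tempered_target_log_prob_fn(x):
    return prior_log_prob_fn(x) + beta * likelihood_log_prob_fn(x)
  ```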

  Args:
    prior_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the prior distribution.
    likelihood_log_prob_fn: Python callable which takes an argument like
      `current_state` (or `*current_state` if it's a list) and returns its
      (possibly unnormalized) log-density under the likelihood distribution.
    current_state: Nested structure of `Tensor`s, each of shape
      `concat([[num_particles, b1, ..., bN], latent_part_event_shape])`, where
      `b1, ..., bN` are optional batch dimensions. Each batch represents an
      independent SMC run.
    min_num_steps: The minimal number of kernel transition steps in one mutation
      of the MC samples.
    max_num_steps: The maximum number of kernel transition steps in one
      mutation of the MC samples. Note that the actual number of steps in one
      mutation is tuned during sampling and is typically lower than
      `max_num_steps`.
    max_stage: Integer maximum number of stages allowed for increasing the
      inverse temperature from 0 to 1.
    make_kernel_fn: Python `callable` which returns a `TransitionKernel`-like
      object. Must take the `TransitionKernel`'s `target_log_prob_fn`, the
      current state, and the per-particle scalings as arguments (see
      `make_rwmh_kernel_fn` for an example). Note:
        `sample_sequential_monte_carlo` creates a new `target_log_prob_fn`
        which is an interpolation between the supplied `prior_log_prob_fn` and
        `likelihood_log_prob_fn`; it is this interpolated function which is
        passed to `make_kernel_fn`.
    tuning_fn: Python `callable` which takes the number of steps, the log
      scaling, and the log acceptance ratio from the last mutation and outputs
      the number of steps and log scaling for the next mutation.
    make_tempered_target_log_prob_fn: Python `callable` that takes the
      `prior_log_prob_fn`, `likelihood_log_prob_fn`, and `inverse_temperatures`
      and creates a `target_log_prob_fn` `callable` that is passed to
      `make_kernel_fn`.
    resample_fn: Python `callable` to generate the indices of resampled
      particles, given their weights. Generally, one of
      `tfp.experimental.mcmc.resample_independent` or
      `tfp.experimental.mcmc.resample_systematic`, or any function with the same
      signature.
      Default value: `tfp.experimental.mcmc.resample_systematic`.
    ess_threshold_ratio: Target ratio of the effective sample size to the
      number of particles; the tempering schedule is adapted to hold the ESS
      of the incremental weights at this level.
      Default value: `0.5`.
    parallel_iterations: The number of iterations allowed to run in parallel. It
      must be a positive integer. See `tf.while_loop` for more details.
    seed: Python integer or TFP seedstream to seed the random number generator.
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'sample_sequential_monte_carlo').

  Returns:
    n_stage: Integer number of mutation stages that SMC ran.
    final_state: `Tensor` or Python `list` of `Tensor`s representing the
      final state(s) of the Markov chain(s), i.e., the (approximate)
      posterior samples.
    final_kernel_results: `collections.namedtuple` of internal calculations used
      to advance the chain.

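  #### Examples

  A minimal sketch of drawing posterior samples for a conjugate normal model
  (the model and numbers here are illustrative, not from the library):

  ```python
  import tensorflow_probability as tfp

  tfd = tfp.distributions

  prior = tfd.Normal(loc=0., scale=1.)

  def likelihood_log_prob_fn(x):
    # Log-likelihood of a single observation y = 0.5 under Normal(x, 0.2).
    return tfd.Normal(loc=x, scale=0.2).log_prob(0.5)

  init_state = prior.sample(1000)  # 1000 particles.
  n_stage, final_state, _ = (
      tfp.experimental.mcmc.sample_sequential_monte_carlo(
          prior.log_prob,
          likelihood_log_prob_fn,
          init_state,
          seed=42))
  # `final_state` holds (approximate) samples from the posterior p(x | y).
  ```
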
  #### References

  [1] Del Moral, Pierre, Arnaud Doucet, and Ajay Jasra. An adaptive
      sequential Monte Carlo method for approximate Bayesian computation.
      _Statistics and Computing_, 22(5):1009-1020, 2012.

  """

  with tf.name_scope(name or 'sample_sequential_monte_carlo'):
    is_seeded = seed is not None
    seed = samplers.sanitize_seed(seed, salt='mcmc.sample_smc')

    unwrap_state_list = not tf.nest.is_nested(current_state)
    if unwrap_state_list:
      current_state = [current_state]
    current_state = [
        tf.convert_to_tensor(s, dtype_hint=tf.float32) for s in current_state
    ]

    # Initial preprocessing at Stage 0
    likelihood_log_prob = likelihood_log_prob_fn(*current_state)

    likelihood_rank = ps.rank(likelihood_log_prob)
    dimension = ps.reduce_sum(
        [ps.reduce_prod(ps.shape(x)[likelihood_rank:]) for x in current_state])

    # We infer the particle shapes from the resulting likelihood:
    # [num_particles, b1, ..., bN]
    particle_shape = ps.shape(likelihood_log_prob)
    num_particles, batch_shape = particle_shape[0], particle_shape[1:]
    effective_sample_size_threshold = tf.cast(
        num_particles * ess_threshold_ratio, tf.int32)
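
    # For example, with `current_state = [tf.zeros([1000, 3, 2])]` and a
    # resulting likelihood of shape [1000, 3], we get num_particles == 1000,
    # batch_shape == [3], and dimension == 2 (hypothetical shapes, for
    # illustration only).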

    # TODO(b/152412213): Revisit this default parameter.
    # Default to the optimal scaling of a random walk kernel for a d-dimensional
    # normal distributed targets: 2.38 ** 2 / d.
    # For more detail see:
    # Roberts GO, Gelman A, Gilks WR. Weak convergence and optimal scaling of
    # random walk Metropolis algorithms. _The annals of applied probability_.
    # 1997;7(1):110-20.
    scale_start = (
        tf.constant(2.38**2, dtype=likelihood_log_prob.dtype) /
        tf.constant(dimension, dtype=likelihood_log_prob.dtype))

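    # At stage 0 the inverse temperature is 0, so the initial kernel targets
    # the prior alone.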
    inverse_temperature = tf.zeros(batch_shape, dtype=likelihood_log_prob.dtype)
    scalings = ps.ones_like(likelihood_log_prob) * ps.minimum(scale_start, 1.)
    kernel = make_kernel_fn(
        make_tempered_target_log_prob_fn(
            prior_log_prob_fn,
            likelihood_log_prob_fn,
            inverse_temperature),
        current_state,
        scalings)
    pkr = kernel.bootstrap_results(current_state)
    _, kernel_target_log_prob = gather_mh_like_result(pkr)

    particle_info = ParticleInfo(
        log_accept_prob=ps.zeros_like(likelihood_log_prob),
        log_scalings=tf.math.log(scalings),
        tempered_log_prob=kernel_target_log_prob,
        likelihood_log_prob=likelihood_log_prob,
    )

    current_pkr = SMCResults(
        num_steps=tf.convert_to_tensor(
            max_num_steps, dtype=tf.int32, name='num_steps'),
        inverse_temperature=inverse_temperature,
        log_marginal_likelihood=tf.zeros_like(inverse_temperature),
        particle_info=particle_info
    )

    def update_weights_temperature(inverse_temperature, likelihood_log_prob):
      """Calculate the next inverse temperature and update weights."""
      likelihood_diff = likelihood_log_prob - tf.reduce_max(
          likelihood_log_prob, axis=0)

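      # Bisection search for the next inverse temperature: find the largest
      # `new_beta` whose incremental importance weights keep the effective
      # sample size at `effective_sample_size_threshold`. The upper bound
      # starts above 1 so that the final stage (beta == 1) is reachable; the
      # result is clipped to [0, 1] below.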
      def _body_fn(new_beta, upper_beta, lower_beta, eff_size, log_weights):
        """One iteration of the temperature and weight update."""
        new_beta = (lower_beta + upper_beta) / 2.0
        log_weights = (new_beta - inverse_temperature) * likelihood_diff
        log_weights_norm = tf.math.log_softmax(log_weights, axis=0)
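        # Effective sample size ESS = 1 / sum_i(w_i**2) for normalized
        # weights w_i, evaluated in log space as exp(-logsumexp(2 * log_w)).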
        eff_size = tf.cast(
            tf.exp(-tf.math.reduce_logsumexp(2 * log_weights_norm, axis=0)),
            tf.int32)
        upper_beta = tf.where(
            eff_size < effective_sample_size_threshold,
            new_beta, upper_beta)
        lower_beta = tf.where(
            eff_size < effective_sample_size_threshold,
            lower_beta, new_beta)
        return new_beta, upper_beta, lower_beta, eff_size, log_weights

      def _cond_fn(new_beta, upper_beta, lower_beta, eff_size, *_):  # pylint: disable=unused-argument
        # TODO(junpenglao): revisit threshold below to be dtype specific.
        threshold = 1e-6
        return (
            tf.math.reduce_any(upper_beta - lower_beta > threshold) &
            tf.math.reduce_any(eff_size != effective_sample_size_threshold)
            )

      (new_beta, upper_beta, lower_beta, eff_size, log_weights) = tf.while_loop(  # pylint: disable=unused-variable
          cond=_cond_fn,
          body=_body_fn,
          loop_vars=(
              tf.zeros_like(inverse_temperature),
              tf.fill(
                  ps.shape(inverse_temperature),
                  tf.constant(2, inverse_temperature.dtype)),
              inverse_temperature,
              tf.zeros_like(inverse_temperature, dtype=tf.int32),
              tf.zeros_like(likelihood_diff)),
          parallel_iterations=parallel_iterations
          )

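      # If the bisection overshoots past beta == 1, recompute the weights at
      # the clipped value beta == 1.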
      log_weights = tf.where(new_beta < 1.,
                             log_weights,
                             (1. - inverse_temperature) * likelihood_diff)
      marginal_loglike_ = reduce_logmeanexp(
          (new_beta - inverse_temperature) * likelihood_log_prob, axis=0)
      new_inverse_temperature = tf.clip_by_value(new_beta, 0., 1.)

      return marginal_loglike_, new_inverse_temperature, log_weights

    def mutate(
        current_state,
        log_scalings,
        num_steps,
        inverse_temperature):
      """Mutate the state using a Transition kernel."""
      with tf.name_scope('mutate_states'):
        scalings = tf.exp(log_scalings)
        kernel = make_kernel_fn(
            make_tempered_target_log_prob_fn(
                prior_log_prob_fn,
                likelihood_log_prob_fn,
                inverse_temperature),
            current_state,
            scalings)
        pkr = kernel.bootstrap_results(current_state)
        kernel_log_accept_ratio, _ = gather_mh_like_result(pkr)

        def mutate_onestep(i, seed, state, pkr, log_accept_prob_sum):
          iter_seed, next_seed = (
              samplers.split_seed(seed) if is_seeded else (None, seed))

          one_step_kwargs = dict(seed=iter_seed) if is_seeded else {}
          next_state, next_kernel_results = kernel.one_step(
              state, pkr, **one_step_kwargs)
          kernel_log_accept_ratio, _ = gather_mh_like_result(
              next_kernel_results)
          log_accept_prob = tf.minimum(kernel_log_accept_ratio, 0.)
          log_accept_prob_sum = log_add_exp(log_accept_prob_sum,
                                            log_accept_prob)
          return [
              i + 1, next_seed, next_state, next_kernel_results,
              log_accept_prob_sum
          ]

        (
            _, _,
            next_state,
            next_kernel_results,
            log_accept_prob_sum
        ) = tf.while_loop(
            cond=lambda i, *args: i < num_steps,
            body=mutate_onestep,
            loop_vars=(
                tf.zeros([], dtype=tf.int32),
                seed,
                current_state,
                pkr,
                # We accumulate the acceptance probability in log space.
                tf.fill(
                    ps.shape(kernel_log_accept_ratio),
                    tf.constant(-np.inf, kernel_log_accept_ratio.dtype))
                ),
            parallel_iterations=parallel_iterations
            )
        _, kernel_target_log_prob = gather_mh_like_result(next_kernel_results)
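        # Per-particle average acceptance probability over the mutation,
        # accumulated and averaged in log space.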
        avg_log_accept_prob_per_particle = log_accept_prob_sum - tf.math.log(
            tf.cast(num_steps + 1, log_accept_prob_sum.dtype))
        return (next_state,
                avg_log_accept_prob_per_particle,
                kernel_target_log_prob)

    # One SMC step.
    def smc_body_fn(stage, state, smc_kernel_result):
      """Run one stage of SMC with constant temperature."""
      (
          new_marginal,
          new_inv_temperature,
          log_weights
      ) = update_weights_temperature(
          smc_kernel_result.inverse_temperature,
          smc_kernel_result.particle_info.likelihood_log_prob)
      # TODO(b/152412213) Use a tf.scan to better collect debug info.
      if PRINT_DEBUG:
        tf.print(
            'Stage:', stage,
            'Beta:', new_inv_temperature,
            'n_steps:', smc_kernel_result.num_steps,
            'accept:', tf.exp(reduce_logmeanexp(
                smc_kernel_result.particle_info.log_accept_prob, axis=0)),
            'scaling:', tf.exp(reduce_logmeanexp(
                smc_kernel_result.particle_info.log_scalings, axis=0))
            )
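      # Resample the particles (together with their per-particle info) in
      # proportion to the incremental importance weights.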
      (resampled_state,
       resampled_particle_info), _, _ = weighted_resampling.resample(
           particles=(state, smc_kernel_result.particle_info),
           log_weights=log_weights,
           resample_fn=resample_fn,
           seed=seed)
      next_num_steps, next_log_scalings = tuning_fn(
          smc_kernel_result.num_steps,
          resampled_particle_info.log_scalings,
          resampled_particle_info.log_accept_prob)
      # Skip tuning at stage 0.
      next_num_steps = tf.where(stage == 0,
                                smc_kernel_result.num_steps,
                                next_num_steps)
      next_log_scalings = tf.where(stage == 0,
                                   resampled_particle_info.log_scalings,
                                   next_log_scalings)
      next_num_steps = tf.clip_by_value(
          next_num_steps, min_num_steps, max_num_steps)

      next_state, log_accept_prob, tempered_log_prob = mutate(
          resampled_state,
          next_log_scalings,
          next_num_steps,
          new_inv_temperature)
      next_pkr = SMCResults(
          num_steps=next_num_steps,
          inverse_temperature=new_inv_temperature,
          log_marginal_likelihood=(new_marginal +
                                   smc_kernel_result.log_marginal_likelihood),
          particle_info=ParticleInfo(
              log_accept_prob=log_accept_prob,
              log_scalings=next_log_scalings,
              tempered_log_prob=tempered_log_prob,
              likelihood_log_prob=likelihood_log_prob_fn(*next_state),
          ))
      return stage + 1, next_state, next_pkr

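    # Run SMC stages until every batch member reaches inverse temperature 1
    # or `max_stage` stages have elapsed.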
    (
        n_stage,
        final_state,
        final_kernel_results
    ) = tf.while_loop(
        cond=lambda i, state, pkr: (  # pylint: disable=g-long-lambda
            (i < max_stage) &
            tf.reduce_any(pkr.inverse_temperature < 1.)),
        body=smc_body_fn,
        loop_vars=(
            tf.zeros([], dtype=tf.int32),
            current_state,
            current_pkr),
        parallel_iterations=parallel_iterations
        )
    if unwrap_state_list:
      final_state = final_state[0]
    return n_stage, final_state, final_kernel_results