def one_step()

in tensorflow_probability/python/mcmc/replica_exchange_mc.py [0:0]


  def one_step(self, current_state, previous_kernel_results, seed=None):
    """Takes one step of the TransitionKernel.

    Each call (1) advances every replica with the inner kernel built by
    `make_kernel_fn` on the tempered target, (2) asks `swap_proposal_fn` for a
    permutation of replicas and accepts or rejects it separately for each
    batch member, and (3) updates the inner kernel results for the post-swap
    states (via `_swap_log_prob_and_maybe_grads` and
    `_set_swapped_fields_to_nan`).

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s). Only its list-likeness is
        inspected here; the states actually stepped are read from
        `previous_kernel_results.post_swap_replica_states`.
      previous_kernel_results: A (possibly nested) `tuple`, `namedtuple` or
        `list` of `Tensor`s representing internal calculations made within the
        previous call to this function (or as returned by `bootstrap_results`).
      seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

    Returns:
      next_state: `Tensor` or Python `list` of `Tensor`s representing the
        next state(s) of the Markov chain(s). A `list` is returned only when
        `current_state` is list-like; otherwise the first state part is
        returned.
      kernel_results: A (possibly nested) `tuple`, `namedtuple` or `list` of
        `Tensor`s representing internal calculations made within this function.
        This includes replica states.

    Raises:
      TypeError: If `make_kernel_fn` raises a `TypeError` whose message
        contains 'argument'. The error is replaced by one explaining that
        `make_kernel_fn` no longer receives a `seed` argument. Any other
        `TypeError` from `make_kernel_fn` or `swap_proposal_fn` propagates
        unchanged, except for the `step_count` fallback described inline.
    """

    # The code below propagates one step states of shape
    #  [n_replica] + batch_shape + event_shape.
    #
    # The step is done in three parts:
    #  1) Call one_step to transition states via a tempered version of
    #     self.target_log_prob_fn (see _replica_target_log_prob).
    #  2) Permute values in states
    #  3) Update state-dependent values, such as log_probs.
    #
    # We chose to swap states, rather than temperatures, because...
    # (i)  If swapping temperatures, you *still* have to swap log_probs to
    #      determine acceptance, as well as states (for kernel results).
    #      So it's just as difficult to swap temperatures.
    # (ii) If swapping temperatures, you have to take care to swap any user-
    #      supplied temperature related things (like step size).
    #      A-priori, we don't know what else will need to be swapped!
    # (iii)In both cases, the kernel results need to be updated in a non-trivial
    #      manner....so we either special-case, or use bootstrap.

    with tf.name_scope(mcmc_util.make_name(self.name, 'remc', 'one_step')):
      # Force a read in case the `inverse_temperatures` is a `tf.Variable`.
      inverse_temperatures = tf.convert_to_tensor(
          previous_kernel_results.inverse_temperatures,
          name='inverse_temperatures')

      # The inner kernel is rebuilt on every step so that it targets the
      # log prob tempered with the freshly read `inverse_temperatures`.
      target_log_prob_for_inner_kernel = _make_replica_target_log_prob_fn(
          target_log_prob_fn=self.target_log_prob_fn,
          inverse_temperatures=inverse_temperatures,
          untempered_log_prob_fn=self.untempered_log_prob_fn,
          tempered_log_prob_fn=self.tempered_log_prob_fn,
      )
      # TODO(b/159636942): Clean up the helpful error msg after 2020-11-10.
      try:
        inner_kernel = self.make_kernel_fn(  # pylint: disable=not-callable
            target_log_prob_for_inner_kernel)
      except TypeError as e:
        # Only signature-mismatch style errors (message mentions 'argument')
        # are rewritten; everything else is re-raised untouched.
        if 'argument' not in str(e):
          raise
        raise TypeError(
            '`ReplicaExchangeMC`s `make_kernel_fn` no longer receives a `seed` '
            'argument. `TransitionKernel` instances now receive seeds via '
            '`one_step`.')

      seed = samplers.sanitize_seed(seed)  # Retain for diagnostics.
      # Separate seeds for: the inner kernel step, the swap proposal, and the
      # log-uniform draws used in the accept/reject test.
      inner_seed, swap_seed, logu_seed = samplers.split_seed(seed, n=3)
      # Step the inner TransitionKernel.
      [
          pre_swap_replica_states,
          pre_swap_replica_results,
      ] = inner_kernel.one_step(
          previous_kernel_results.post_swap_replica_states,
          previous_kernel_results.post_swap_replica_results,
          seed=inner_seed)

      pre_swap_replica_target_log_prob = _get_field(
          # These are tempered log probs (have been divided by temperature).
          pre_swap_replica_results, 'target_log_prob')

      # Shape bookkeeping is derived from the log prob, whose leading axis is
      # the replica axis and whose remaining axes are treated as batch axes.
      dtype = pre_swap_replica_target_log_prob.dtype
      replica_and_batch_shape = ps.shape(
          pre_swap_replica_target_log_prob)
      batch_shape = replica_and_batch_shape[1:]
      replica_and_batch_rank = ps.rank(
          pre_swap_replica_target_log_prob)
      num_replica = ps.size0(inverse_temperatures)

      # From here on `inverse_temperatures` has the same shape as the log prob.
      inverse_temperatures = bu.left_justified_broadcast_to(
          inverse_temperatures, replica_and_batch_shape)

      # Now that each replica has done one_step, it is time to consider swaps.

      # swaps.shape = [n_replica], and is a "once only" permutation, meaning it
      # is achievable by a sequence of pairwise permutations, where each element
      # is moved at most once.
      # E.g. if swaps = [1, 0, 2], we will consider swapping the states at
      # temperatures 0 and 1, keeping 2 fixed.  This exact same swap is
      # considered for *every* batch member.  Of course some batch members may
      # accept and some reject.
      try:
        swaps = tf.cast(
            self.swap_proposal_fn(  # pylint: disable=not-callable
                num_replica,
                batch_shape=batch_shape,
                seed=swap_seed,
                step_count=previous_kernel_results.step_count),
            dtype=tf.int32)
      except TypeError as e:
        # Backwards-compatibility path: retry without `step_count` only when
        # the error message names that argument; otherwise re-raise.
        if 'step_count' not in str(e):
          raise
        warnings.warn(
            'The `swap_proposal_fn` given to ReplicaExchangeMC did not accept '
            'the `step_count` argument. Falling back to omitting the '
            'argument. This fallback will be removed after 24-Oct-2020.')
        swaps = tf.cast(
            self.swap_proposal_fn(  # pylint: disable=not-callable
                num_replica,
                batch_shape=batch_shape,
                seed=swap_seed),
            dtype=tf.int32)

      # `null_swaps` is the identity permutation range(num_replica), expanded
      # on the right to the rank of `swaps`.
      null_swaps = bu.left_justified_expand_dims_like(
          tf.range(num_replica, dtype=swaps.dtype), swaps)
      swaps = _maybe_embed_swaps_validation(swaps, null_swaps,
                                            self.validate_args)

      # Un-temper the log probs for use in the swap acceptance ratio.
      if self.tempered_log_prob_fn is None:
        # Efficient way of re-evaluating target_log_prob_fn on the
        # pre_swap_replica_states.
        # NOTE(review): the comment below speaks of `untempered_log_prob_fn`
        # while the branch tests `tempered_log_prob_fn`; presumably the two are
        # only ever supplied together -- confirm against the constructor.
        untempered_negative_energy_ignoring_ulp = (
            # Since untempered_log_prob_fn is None, we may assume
            # inverse_temperatures > 0 (else the target is improper).
            pre_swap_replica_target_log_prob / inverse_temperatures)
      else:
        # The untempered_log_prob_fn does not factor into the acceptance ratio.
        # Proof: Suppose the tempered target is
        #   p_k(x) = f(x)^{beta_k} g(x),
        # So f(x) is tempered, and g(x) is not.  Then, the acceptance ratio for
        # a 1 <--> 2 swap is...
        #   (p_1(x_2) p_2(x_1)) / (p_1(x_1) p_2(x_2))
        # which depends only on f(x), since terms involving g(x) cancel.
        untempered_negative_energy_ignoring_ulp = self.tempered_log_prob_fn(
            *pre_swap_replica_states)

      # Since `swaps` is its own inverse permutation we automatically know the
      # swap counterpart: range(num_replica). We use this idea to compute the
      # acceptance in a vectorized manner at the cost of wasting roughly half
      # our computation. Although we could use `unique` to solve this problem,
      # we expect the cost of `unique` to be higher than the dozens of wasted
      # arithmetic calculations. Worse, it'd mean we need dynamic sized Tensors
      # (eg, using `tf.where(bool)`) and so we wouldn't be able to XLA compile.

      # Note: diffs would normally be "proposed - current" however energy is
      # flipped since `energy == -log_prob`.
      # Note: The untempered_log_prob_fn (if provided) is not included in
      # untempered_negative_energy_ignoring_ulp, and hence does not factor
      # into energy_diff. Why? Because, it cancels out in the acceptance ratio.
      energy_diff = (
          untempered_negative_energy_ignoring_ulp -
          mcmc_util.index_remapping_gather(
              untempered_negative_energy_ignoring_ulp,
              swaps, name='gather_swap_tlp'))
      swapped_inverse_temperatures = mcmc_util.index_remapping_gather(
          inverse_temperatures, swaps, name='gather_swap_temps')
      inverse_temp_diff = swapped_inverse_temperatures - inverse_temperatures

      # If i and j are swapping, log_accept_ratio[i] and log_accept_ratio[j]
      # are equal (both factors change sign when i and j are exchanged).
      log_accept_ratio = (
          energy_diff * bu.left_justified_expand_dims_to(
              inverse_temp_diff, replica_and_batch_rank))

      # Any non-finite ratio (NaN or +/-inf) is replaced by -inf, so the
      # `tf.less` test below can never accept such a swap.
      log_accept_ratio = tf.where(
          tf.math.is_finite(log_accept_ratio),
          log_accept_ratio, tf.constant(-np.inf, dtype=dtype))

      # Produce log[Uniform] draws that are identical at swapped indices.
      log_uniform = tf.math.log(
          samplers.uniform(shape=replica_and_batch_shape,
                           dtype=dtype,
                           seed=logu_seed))
      # Both members of a swapped pair (i, j) gather the draw stored at
      # min(i, j), so they reach the same accept/reject decision.
      anchor_swaps = tf.minimum(swaps, null_swaps)
      log_uniform = mcmc_util.index_remapping_gather(log_uniform, anchor_swaps)

      is_swap_accepted_mask = tf.less(
          log_uniform,
          log_accept_ratio,
          name='is_swap_accepted_mask')

      def _swap_tensor(x):
        # Where the swap was accepted take the value from the swap partner,
        # otherwise keep the value of `x` in place.
        return mcmc_util.choose(
            is_swap_accepted_mask,
            mcmc_util.index_remapping_gather(x, swaps), x)

      post_swap_replica_states = [
          _swap_tensor(s) for s in pre_swap_replica_states]

      expanded_null_swaps = bu.left_justified_broadcast_to(
          null_swaps, replica_and_batch_shape)
      is_swap_proposed = _compute_swap_notmatrix(
          # Broadcast both so they have shape [num_replica] + batch_shape.
          # This (i) makes them have same shape as is_swap_accepted, and
          # (ii) keeps shape consistent if someday swaps has a batch shape.
          expanded_null_swaps,
          bu.left_justified_broadcast_to(swaps, replica_and_batch_shape))

      # To get is_swap_accepted in ordered position, we use
      # _compute_swap_notmatrix on current and next replica positions.
      post_swap_replica_position = _swap_tensor(expanded_null_swaps)

      is_swap_accepted = _compute_swap_notmatrix(
          post_swap_replica_position,
          expanded_null_swaps)

      if self._state_includes_replicas:
        post_swap_states = post_swap_replica_states
      else:
        # Return only index 0 along the leading (replica) axis of each part.
        post_swap_states = [s[0] for s in post_swap_replica_states]

      # The inner results are produced from the *pre-swap* results plus the
      # post-swap states; see the two helpers for exactly which fields change.
      post_swap_replica_results = _set_swapped_fields_to_nan(
          _swap_log_prob_and_maybe_grads(
              pre_swap_replica_results,
              post_swap_replica_states,
              inner_kernel))

      if mcmc_util.is_list_like(current_state):
        # We *always* canonicalize the states in the kernel results.
        states = post_swap_states
      else:
        # Caller passed a single (non-list) state: unwrap the one state part.
        states = post_swap_states[0]

      post_swap_kernel_results = ReplicaExchangeMCKernelResults(
          post_swap_replica_states=post_swap_replica_states,
          pre_swap_replica_results=pre_swap_replica_results,
          post_swap_replica_results=post_swap_replica_results,
          is_swap_proposed=is_swap_proposed,
          is_swap_accepted=is_swap_accepted,
          is_swap_proposed_adjacent=_sub_diag(is_swap_proposed),
          is_swap_accepted_adjacent=_sub_diag(is_swap_accepted),
          # Store the original pkr.inverse_temperatures in case its a
          # `tf.Variable`.
          inverse_temperatures=previous_kernel_results.inverse_temperatures,
          swaps=swaps,
          step_count=previous_kernel_results.step_count + 1,
          # The sanitized (pre-split) seed used for this step.
          seed=seed,
          # Negated so that the stored quantity is an energy (-log prob),
          # computed from the pre-swap states.
          potential_energy=-untempered_negative_energy_ignoring_ulp,
      )

      return states, post_swap_kernel_results