def one_step()

in tensorflow_probability/python/experimental/mcmc/elliptical_slice_sampler.py [0:0]


  def one_step(self, current_state, previous_kernel_results, seed=None):
    """Runs one iteration of the Elliptical Slice Sampler.

    A draw from `normal_sampler_fn` is combined with the current state via
    `_rotate_on_ellipse`. A per-chain log-likelihood threshold is drawn as
    `log_likelihood(current_state) + log(u)` with `u` uniform, an initial angle
    is sampled uniformly from `[0, 2 * pi)`, and the angle bracket (which always
    contains zero) is shrunk until no chain has a candidate whose
    log-likelihood is below its threshold.

    Args:
      current_state: `Tensor` or Python `list` of `Tensor`s representing the
        current state(s) of the Markov chain(s). The first `r` dimensions
        index independent chains,
        `r = tf.rank(log_likelihood_fn(*normal_sampler_fn()))`.
      previous_kernel_results: `collections.namedtuple` containing `Tensor`s
        representing values from previous calls to this function (or from the
        `bootstrap_results` function.)
      seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

    Returns:
      next_state: Tensor or Python list of `Tensor`s representing the state(s)
        of the Markov chain(s) after taking exactly one step. Has same type and
        shape as `current_state`.
      kernel_results: `collections.namedtuple` of internal calculations used to
        advance the chain.

    Raises:
      TypeError: if `not log_likelihood.dtype.is_floating`.
    """
    with tf.name_scope(
        mcmc_util.make_name(self.name, 'elliptical_slice', 'one_step')):
      with tf.name_scope('initialize'):
        [
            init_state_parts,
            init_log_likelihood
        ] = _prepare_args(
            self.log_likelihood_fn,
            current_state,
            previous_kernel_results.log_likelihood)

      seed = samplers.sanitize_seed(seed)  # Unsalted, for kernel results.
      normal_seed, u_seed, angle_seed, loop_seed = samplers.split_seed(
          seed, n=4, salt='elliptical_slice_sampler')
      normal_samples = self.normal_sampler_fn(normal_seed)  # pylint: disable=not-callable
      # Always work with a list of parts internally, even for a single Tensor.
      normal_samples = list(normal_samples) if mcmc_util.is_list_like(
          normal_samples) else [normal_samples]
      u = samplers.uniform(
          shape=tf.shape(init_log_likelihood),
          seed=u_seed,
          dtype=init_log_likelihood.dtype.base_dtype,
      )
      # Per-chain slice level: a candidate is accepted once its log-likelihood
      # is no longer below this value.
      threshold = init_log_likelihood + tf.math.log(u)

      starting_angle = samplers.uniform(
          shape=tf.shape(init_log_likelihood),
          minval=0.,
          maxval=2 * np.pi,
          name='angle',
          seed=angle_seed,
          dtype=init_log_likelihood.dtype.base_dtype,
      )
      # Initial bracket is [angle - 2 * pi, angle]; since angle is in
      # [0, 2 * pi) the bracket contains zero.
      starting_angle_min = starting_angle - 2 * np.pi
      starting_angle_max = starting_angle

      starting_state_parts = _rotate_on_ellipse(
          init_state_parts, normal_samples, starting_angle)
      starting_log_likelihood = self.log_likelihood_fn(*starting_state_parts)  # pylint: disable=not-callable

      def chain_not_done(
          seed,
          angle,
          angle_min,
          angle_max,
          current_state_parts,
          current_log_likelihood):
        """Loop predicate: True while any chain is still below its threshold."""
        del seed, angle, angle_min, angle_max, current_state_parts
        return tf.reduce_any(current_log_likelihood < threshold)

      def sample_next_angle(
          seed,
          angle,
          angle_min,
          angle_max,
          current_state_parts,
          current_log_likelihood):
        """Slice sample a new angle, and rotate init_state by that amount."""
        draw_seed, next_seed = samplers.split_seed(seed)
        # Elementwise mask of chains that still need a new candidate.
        not_done = current_log_likelihood < threshold
        # Box in on angle. Only update angles for which we haven't generated a
        # point that beats the threshold.
        angle_min = tf.where(
            (angle < 0) & not_done,
            angle,
            angle_min)
        angle_max = tf.where(
            (angle >= 0) & not_done,
            angle,
            angle_max)
        new_angle = samplers.uniform(
            shape=tf.shape(current_log_likelihood),
            minval=angle_min,
            maxval=angle_max,
            seed=draw_seed,
            dtype=angle.dtype.base_dtype
        )
        # Chains that already met their threshold keep their accepted angle.
        angle = tf.where(not_done, new_angle, angle)
        next_state_parts = _rotate_on_ellipse(
            init_state_parts, normal_samples, angle)

        # Pad the mask against each part's own rank (rather than only the first
        # part's) so the selection does not rely on all parts sharing a rank.
        new_state_parts = [
            tf.where(
                _right_pad_with_ones(not_done, tf.rank(proposed_part)),
                proposed_part,
                current_part)
            for proposed_part, current_part in zip(
                next_state_parts, current_state_parts)
        ]

        return (
            next_seed,
            angle,
            angle_min,
            angle_max,
            new_state_parts,
            self.log_likelihood_fn(*new_state_parts)  # pylint: disable=not-callable
        )

      [
          _,
          next_angle,
          _,
          _,
          next_state_parts,
          next_log_likelihood,
      ] = tf.while_loop(
          cond=chain_not_done,
          body=sample_next_angle,
          loop_vars=[
              loop_seed,
              starting_angle,
              starting_angle_min,
              starting_angle_max,
              starting_state_parts,
              starting_log_likelihood
          ])

      # Mirror the structure of `current_state`: list in, list out; Tensor in,
      # Tensor out.
      return [
          next_state_parts if mcmc_util.is_list_like(
              current_state) else next_state_parts[0],
          EllipticalSliceSamplerKernelResults(
              log_likelihood=next_log_likelihood,
              angle=next_angle,
              normal_samples=normal_samples,
              seed=seed,
          ),
      ]