in tensorflow_probability/python/experimental/mcmc/elliptical_slice_sampler.py [0:0]
def one_step(self, current_state, previous_kernel_results, seed=None):
  """Runs one iteration of the Elliptical Slice Sampler.

  A log-likelihood threshold is drawn as
  `previous_log_likelihood + log(u)` with `u` uniform, and an initial angle
  is drawn uniformly from `[0, 2 * pi)`. Candidate states are produced by
  `_rotate_on_ellipse(init_state_parts, normal_samples, angle)`. While any
  chain's candidate log-likelihood is below its threshold, the angle bracket
  of that chain is shrunk to the rejected angle and a new angle is drawn
  uniformly from the bracket. Chains that already meet their threshold are
  left untouched.

  Args:
    current_state: `Tensor` or Python `list` of `Tensor`s representing the
      current state(s) of the Markov chain(s). The first `r` dimensions
      index independent chains,
      `r = tf.rank(log_likelihood_fn(*normal_sampler_fn()))`.
    previous_kernel_results: `collections.namedtuple` containing `Tensor`s
      representing values from previous calls to this function (or from the
      `bootstrap_results` function.)
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.

  Returns:
    next_state: Tensor or Python list of `Tensor`s representing the state(s)
      of the Markov chain(s) after taking exactly one step. Has same type and
      shape as `current_state`.
    kernel_results: `collections.namedtuple` of internal calculations used to
      advance the chain.

  Raises:
    TypeError: if `not log_likelihood.dtype.is_floating`.
  """
  with tf.name_scope(
      mcmc_util.make_name(self.name, 'elliptical_slice', 'one_step')):
    with tf.name_scope('initialize'):
      [
          init_state_parts,
          init_log_likelihood
      ] = _prepare_args(
          self.log_likelihood_fn,
          current_state,
          previous_kernel_results.log_likelihood)

    seed = samplers.sanitize_seed(seed)  # Unsalted, for kernel results.
    normal_seed, u_seed, angle_seed, loop_seed = samplers.split_seed(
        seed, n=4, salt='elliptical_slice_sampler')
    normal_samples = self.normal_sampler_fn(normal_seed)  # pylint: disable=not-callable
    # Always work with a list of parts, mirroring `init_state_parts`.
    normal_samples = list(normal_samples) if mcmc_util.is_list_like(
        normal_samples) else [normal_samples]

    # Per-chain slice level: log-likelihood of the current state plus the log
    # of a uniform draw.
    u = samplers.uniform(
        shape=tf.shape(init_log_likelihood),
        seed=u_seed,
        dtype=init_log_likelihood.dtype.base_dtype,
    )
    threshold = init_log_likelihood + tf.math.log(u)

    # Initial angle in [0, 2 * pi) and the bracket [angle - 2 * pi, angle]
    # that is shrunk inside the loop.
    starting_angle = samplers.uniform(
        shape=tf.shape(init_log_likelihood),
        minval=0.,
        maxval=2 * np.pi,
        name='angle',
        seed=angle_seed,
        dtype=init_log_likelihood.dtype.base_dtype,
    )
    starting_angle_min = starting_angle - 2 * np.pi
    starting_angle_max = starting_angle

    starting_state_parts = _rotate_on_ellipse(
        init_state_parts, normal_samples, starting_angle)
    starting_log_likelihood = self.log_likelihood_fn(*starting_state_parts)  # pylint: disable=not-callable

    def chain_not_done(
        seed,
        angle,
        angle_min,
        angle_max,
        current_state_parts,
        current_log_likelihood):
      """Loop condition: true while any chain is still below its threshold."""
      del seed, angle, angle_min, angle_max, current_state_parts
      return tf.reduce_any(current_log_likelihood < threshold)

    def sample_next_angle(
        seed,
        angle,
        angle_min,
        angle_max,
        current_state_parts,
        current_log_likelihood):
      """Slice sample a new angle, and rotate init_state by that amount."""
      angle_seed, next_seed = samplers.split_seed(seed)
      # Elementwise mask of chains whose candidate is still below threshold.
      not_done = current_log_likelihood < threshold
      # Box in on angle. Only update angles for which we haven't generated a
      # point that beats the threshold.
      angle_min = tf.where(
          (angle < 0) & not_done,
          angle,
          angle_min)
      angle_max = tf.where(
          (angle >= 0) & not_done,
          angle,
          angle_max)
      new_angle = samplers.uniform(
          shape=tf.shape(current_log_likelihood),
          minval=angle_min,
          maxval=angle_max,
          seed=angle_seed,
          dtype=angle.dtype.base_dtype
      )
      # Finished chains keep the angle they were accepted with.
      angle = tf.where(not_done, new_angle, angle)
      next_state_parts = _rotate_on_ellipse(
          init_state_parts, normal_samples, angle)

      # Pad the mask separately for every state part so that parts of
      # different rank each get a mask aligned with their own leading (chain)
      # dimensions. Padding once to the rank of the first part would
      # misalign the mask for any part whose rank differs.
      # NOTE(review): assumes `_right_pad_with_ones(x, rank)` appends size-1
      # dimensions to `x` up to `rank` -- confirm against its definition.
      new_state_parts = [
          tf.where(
              _right_pad_with_ones(not_done, tf.rank(n_state)),
              n_state,
              c_state)
          for n_state, c_state in zip(next_state_parts, current_state_parts)
      ]

      return (
          next_seed,
          angle,
          angle_min,
          angle_max,
          new_state_parts,
          self.log_likelihood_fn(*new_state_parts)  # pylint: disable=not-callable
      )

    [
        _,
        next_angle,
        _,
        _,
        next_state_parts,
        next_log_likelihood,
    ] = tf.while_loop(
        cond=chain_not_done,
        body=sample_next_angle,
        loop_vars=[
            loop_seed,
            starting_angle,
            starting_angle_min,
            starting_angle_max,
            starting_state_parts,
            starting_log_likelihood
        ])

    # Return the state in the same form (list vs. single `Tensor`) it was
    # given in.
    return [
        next_state_parts if mcmc_util.is_list_like(
            current_state) else next_state_parts[0],
        EllipticalSliceSamplerKernelResults(
            log_likelihood=next_log_likelihood,
            angle=next_angle,
            normal_samples=normal_samples,
            seed=seed,
        ),
    ]