in pyro/infer/mcmc/nuts.py [0:0]
def sample(self, params):
    """Run one step of the No-U-Turn Sampler and return the next sample.

    Starting from the current state ``params`` (a dict of sample-site values),
    this draws a momentum, then repeatedly doubles a trajectory tree to the
    left or right (direction chosen uniformly at each depth) until the tree
    turns, diverges, or ``self._max_tree_depth`` is reached. A proposal is
    accepted from each new subtree with probability proportional to its
    weight, using either multinomial sampling or slice sampling depending on
    ``self.use_multinomial_sampling``. Bookkeeping for warmup adaptation,
    divergence tracking, and the running mean acceptance probability is
    updated as a side effect on ``self``.

    :param dict params: current values keyed by sample-site name.
    :returns: a (shallow) copy of the updated value dict.
    """
    z, potential_energy, z_grads = self._fetch_from_cache()
    # recompute PE when cache is cleared
    if z is None:
        z = params
        z_grads, potential_energy = potential_grad(self.potential_fn, z)
        self._cache(z, potential_energy, z_grads)
    # return early if no sample sites
    elif len(z) == 0:
        self._t += 1
        self._mean_accept_prob = 1.
        if self._t > self._warmup_steps:
            self._accept_cnt += 1
        return z
    # r is the scaled momentum (used for integration); r_unscaled is used
    # for the kinetic energy and the U-turn test.
    r, r_unscaled = self._sample_r(name="r_t={}".format(self._t))
    energy_current = self._kinetic_energy(r_unscaled) + potential_energy

    # Ideally, following a symplectic integrator trajectory, the energy is constant.
    # In that case, we can sample the proposal uniformly, and there is no need to use "slice".
    # However, it is not the case for real situation: there are errors during the computation.
    # To deal with that problem, as in [1], we introduce an auxiliary "slice" variable (denoted
    # by u).
    # The sampling process goes as follows:
    #   first sampling u from initial state (z_0, r_0) according to
    #     u ~ Uniform(0, p(z_0, r_0)),
    #   then sampling state (z, r) from the integrator trajectory according to
    #     (z, r) ~ Uniform({(z', r') in trajectory | p(z', r') >= u}).
    #
    # For more information about slice sampling method, see [3].
    # For another version of NUTS which uses multinomial sampling instead of slice sampling,
    # see [2].
    if self.use_multinomial_sampling:
        log_slice = -energy_current
    else:
        # Rather than sampling the slice variable from `Uniform(0, exp(-energy))`, we can
        # sample log_slice directly using `energy`, so as to avoid potential underflow or
        # overflow issues ([2]).
        slice_exp_term = pyro.sample("slicevar_exp_t={}".format(self._t),
                                     dist.Exponential(scalar_like(energy_current, 1.)))
        log_slice = -energy_current - slice_exp_term

    # The left/right leaves of the current tree; doubling always extends from
    # one of these endpoints.
    z_left = z_right = z
    r_left = r_right = r
    r_left_unscaled = r_right_unscaled = r_unscaled
    z_left_grads = z_right_grads = z_grads
    accepted = False
    # Running sum of unscaled momenta over the whole trajectory, consumed by
    # the U-turn test below.
    r_sum = r_unscaled
    sum_accept_probs = 0.
    num_proposals = 0
    # In multinomial mode the tree weight lives in log space, hence the 0./1.
    # initial values differ between the two modes.
    tree_weight = scalar_like(energy_current, 0. if self.use_multinomial_sampling else 1.)

    # Temporarily disable distributions args checking as
    # NaNs are expected during step size adaptation.
    with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
        # doubling process, stop when turning or diverging
        tree_depth = 0
        while tree_depth < self._max_tree_depth:
            direction = pyro.sample("direction_t={}_treedepth={}".format(self._t, tree_depth),
                                    dist.Bernoulli(probs=scalar_like(tree_weight, 0.5)))
            direction = int(direction.item())
            if direction == 1:  # go to the right, start from the right leaf of current tree
                new_tree = self._build_tree(z_right, r_right, z_right_grads, log_slice,
                                            direction, tree_depth, energy_current)
                # update leaf for the next doubling process
                z_right = new_tree.z_right
                r_right = new_tree.r_right
                r_right_unscaled = new_tree.r_right_unscaled
                z_right_grads = new_tree.z_right_grads
            else:  # go to the left, start from the left leaf of current tree
                new_tree = self._build_tree(z_left, r_left, z_left_grads, log_slice,
                                            direction, tree_depth, energy_current)
                z_left = new_tree.z_left
                r_left = new_tree.r_left
                r_left_unscaled = new_tree.r_left_unscaled
                z_left_grads = new_tree.z_left_grads

            sum_accept_probs = sum_accept_probs + new_tree.sum_accept_probs
            num_proposals = num_proposals + new_tree.num_proposals

            # stop doubling
            if new_tree.diverging:
                # Divergences are only recorded post-warmup (indexed relative
                # to the end of warmup).
                if self._t >= self._warmup_steps:
                    self._divergences.append(self._t - self._warmup_steps)
                break

            if new_tree.turning:
                break

            tree_depth += 1

            # Accept the new subtree's proposal with probability
            # new_tree.weight / tree_weight (computed in log space for the
            # multinomial variant).
            if self.use_multinomial_sampling:
                new_tree_prob = (new_tree.weight - tree_weight).exp()
            else:
                new_tree_prob = new_tree.weight / tree_weight
            rand = pyro.sample("rand_t={}_treedepth={}".format(self._t, tree_depth),
                               dist.Uniform(scalar_like(new_tree_prob, 0.),
                                            scalar_like(new_tree_prob, 1.)))
            if rand < new_tree_prob:
                accepted = True
                z = new_tree.z_proposal
                z_grads = new_tree.z_proposal_grads
                # Cache PE/grads of the accepted proposal so the next call
                # skips the recompute at the top of this method.
                self._cache(z, new_tree.z_proposal_pe, z_grads)

            r_sum = {site_names: r_sum[site_names] + new_tree.r_sum[site_names]
                     for site_names in r_unscaled}
            if self._is_turning(r_left_unscaled, r_right_unscaled, r_sum):  # stop doubling
                break
            else:  # update tree_weight
                if self.use_multinomial_sampling:
                    tree_weight = _logaddexp(tree_weight, new_tree.weight)
                else:
                    tree_weight = tree_weight + new_tree.weight

    # Mean Metropolis acceptance probability over all leapfrog proposals in
    # this step; feeds step-size adaptation during warmup.
    accept_prob = sum_accept_probs / num_proposals

    self._t += 1
    if self._t > self._warmup_steps:
        # Post-warmup: track acceptance count; n indexes post-warmup steps so
        # the running mean below averages over the sampling phase only.
        n = self._t - self._warmup_steps
        if accepted:
            self._accept_cnt += 1
    else:
        # Warmup: adapt step size / mass matrix; mean resets implicitly since
        # n restarts from 1 once warmup ends.
        n = self._t
        self._adapter.step(self._t, z, accept_prob, z_grads)
    self._mean_accept_prob += (accept_prob.item() - self._mean_accept_prob) / n

    return z.copy()