def sample()

in pyro/infer/mcmc/nuts.py [0:0]


    def sample(self, params):
        """
        Run one NUTS transition and return the resulting latent state.

        The starting state ``z`` (with its potential energy and gradients) is read
        from the kernel's cache; ``params`` is only used when the cache is empty.
        A momentum is drawn, then the trajectory is repeatedly doubled in a random
        direction (at most ``self._max_tree_depth`` times) until a divergence or a
        U-turn is detected. After each successful doubling, the new subtree's
        proposal replaces the current state with probability
        ``new_tree.weight / tree_weight`` (computed from log weights when
        ``self.use_multinomial_sampling`` is set), and an accepted state is
        written back to the cache.

        Side effects on the kernel: increments ``self._t``; after warmup,
        increments ``self._accept_cnt`` when any proposal was accepted and
        appends the 0-based post-warmup step index to ``self._divergences`` on
        divergence; during warmup, calls ``self._adapter.step``; and updates the
        running mean ``self._mean_accept_prob`` (whose count restarts at 1 on the
        first post-warmup step).

        :param params: the current state, used as the starting point only when
            the cache has been cleared. Presumably a dict mapping sample-site
            names to tensors -- TODO confirm against callers.
        :returns: a shallow copy of the (possibly unchanged) state dict. On the
            no-sample-sites early return, the cached object itself is returned.
        """
        z, potential_energy, z_grads = self._fetch_from_cache()
        # recompute PE when cache is cleared
        if z is None:
            z = params
            z_grads, potential_energy = potential_grad(self.potential_fn, z)
            self._cache(z, potential_energy, z_grads)
        # return early if no sample sites
        # NOTE(review): because this is an `elif`, a state freshly loaded from an
        # empty `params` above skips this shortcut and goes through tree building.
        # This path also returns the cached object itself (not a copy, unlike the
        # final return) and does not call `self._adapter.step` during warmup.
        elif len(z) == 0:
            self._t += 1
            self._mean_accept_prob = 1.
            if self._t > self._warmup_steps:
                self._accept_cnt += 1
            return z
        # Draw the momentum. `r_unscaled` is the form used for the kinetic energy
        # and for the U-turn checks; energy_current is the initial Hamiltonian.
        r, r_unscaled = self._sample_r(name="r_t={}".format(self._t))
        energy_current = self._kinetic_energy(r_unscaled) + potential_energy

        # Ideally, following a symplectic integrator trajectory, the energy is constant.
        # In that case, we can sample the proposal uniformly, and there is no need to use "slice".
        # However, this is not the case in practice: there are errors during the computation.
        # To deal with that problem, as in [1], we introduce an auxiliary "slice" variable (denoted
        # by u).
        # The sampling process goes as follows:
        #   first sampling u from initial state (z_0, r_0) according to
        #     u ~ Uniform(0, p(z_0, r_0)),
        #   then sampling state (z, r) from the integrator trajectory according to
        #     (z, r) ~ Uniform({(z', r') in trajectory | p(z', r') >= u}).
        #
        # For more information about slice sampling method, see [3].
        # For another version of NUTS which uses multinomial sampling instead of slice sampling,
        # see [2].

        if self.use_multinomial_sampling:
            # No slice variable is drawn: the initial log density is passed to
            # `_build_tree` in place of log(u).
            log_slice = -energy_current
        else:
            # Rather than sampling the slice variable from `Uniform(0, exp(-energy))`, we can
            # sample log_slice directly using `energy`, so as to avoid potential underflow or
            # overflow issues ([2]).
            # If E ~ Exponential(1) then exp(-E) ~ Uniform(0, 1), so
            # -energy - E has the distribution of log(u).
            slice_exp_term = pyro.sample("slicevar_exp_t={}".format(self._t),
                                         dist.Exponential(scalar_like(energy_current, 1.)))
            log_slice = -energy_current - slice_exp_term

        # Both ends of the trajectory start at the initial state.
        z_left = z_right = z
        r_left = r_right = r
        r_left_unscaled = r_right_unscaled = r_unscaled
        z_left_grads = z_right_grads = z_grads
        accepted = False
        # Running sum of unscaled momenta over the whole trajectory, used for the
        # merged-tree U-turn check after each doubling.
        r_sum = r_unscaled
        # Accumulated over all subtrees to form the average acceptance statistic.
        sum_accept_probs = 0.
        num_proposals = 0
        # Weight of the current tree: a log weight (initial point counts as log 1 = 0)
        # in multinomial mode, a plain weight (initial point counts as 1) in slice mode.
        tree_weight = scalar_like(energy_current, 0. if self.use_multinomial_sampling else 1.)

        # Temporarily disable distributions args checking during warmup
        # (self._t < self._warmup_steps), as NaNs are expected during step size adaptation.
        with optional(pyro.validation_enabled(False), self._t < self._warmup_steps):
            # doubling process, stop when turning or diverging
            tree_depth = 0
            while tree_depth < self._max_tree_depth:
                # Fair coin: 1 extends the right end of the trajectory, 0 the left end.
                direction = pyro.sample("direction_t={}_treedepth={}".format(self._t, tree_depth),
                                        dist.Bernoulli(probs=scalar_like(tree_weight, 0.5)))
                direction = int(direction.item())
                if direction == 1:  # go to the right, start from the right leaf of current tree
                    new_tree = self._build_tree(z_right, r_right, z_right_grads, log_slice,
                                                direction, tree_depth, energy_current)
                    # update leaf for the next doubling process
                    z_right = new_tree.z_right
                    r_right = new_tree.r_right
                    r_right_unscaled = new_tree.r_right_unscaled
                    z_right_grads = new_tree.z_right_grads
                else:  # go to the left, start from the left leaf of current tree
                    new_tree = self._build_tree(z_left, r_left, z_left_grads, log_slice,
                                                direction, tree_depth, energy_current)
                    z_left = new_tree.z_left
                    r_left = new_tree.r_left
                    r_left_unscaled = new_tree.r_left_unscaled
                    z_left_grads = new_tree.z_left_grads

                sum_accept_probs = sum_accept_probs + new_tree.sum_accept_probs
                num_proposals = num_proposals + new_tree.num_proposals

                # stop doubling
                # A divergent or internally U-turning subtree is discarded: the
                # proposal step below is never reached for it.
                if new_tree.diverging:
                    # record the 0-based post-warmup index of this step
                    if self._t >= self._warmup_steps:
                        self._divergences.append(self._t - self._warmup_steps)
                    break

                if new_tree.turning:
                    break

                # Incremented before the `rand` site below, so that site's name
                # uses the new depth.
                tree_depth += 1

                # Probability of moving to the new subtree's proposal; a value >= 1
                # always accepts since `rand` is drawn from Uniform(0, 1).
                if self.use_multinomial_sampling:
                    new_tree_prob = (new_tree.weight - tree_weight).exp()
                else:
                    new_tree_prob = new_tree.weight / tree_weight
                rand = pyro.sample("rand_t={}_treedepth={}".format(self._t, tree_depth),
                                   dist.Uniform(scalar_like(new_tree_prob, 0.),
                                                scalar_like(new_tree_prob, 1.)))
                if rand < new_tree_prob:
                    accepted = True
                    z = new_tree.z_proposal
                    z_grads = new_tree.z_proposal_grads
                    # cache the proposal so the next call starts from it
                    self._cache(z, new_tree.z_proposal_pe, z_grads)

                # Extend the momentum sum with the new subtree and check the merged
                # trajectory for a U-turn. This runs after the proposal step above,
                # so the new subtree's proposal may be kept even if the merged tree turns.
                r_sum = {site_names: r_sum[site_names] + new_tree.r_sum[site_names]
                         for site_names in r_unscaled}
                if self._is_turning(r_left_unscaled, r_right_unscaled, r_sum):  # stop doubling
                    break
                else:  # update tree_weight
                    # merge the new subtree's weight (log-sum-exp in multinomial mode)
                    if self.use_multinomial_sampling:
                        tree_weight = _logaddexp(tree_weight, new_tree.weight)
                    else:
                        tree_weight = tree_weight + new_tree.weight

        # Average acceptance statistic over all proposals in the trajectory.
        # NOTE(review): assumes at least one doubling runs (self._max_tree_depth >= 1)
        # and that subtrees report num_proposals > 0; otherwise this divides by zero.
        accept_prob = sum_accept_probs / num_proposals

        self._t += 1
        if self._t > self._warmup_steps:
            # n counts post-warmup steps, so the running mean restarts after warmup.
            n = self._t - self._warmup_steps
            if accepted:
                self._accept_cnt += 1
        else:
            n = self._t
            # during warmup, feed the acceptance statistic to the adapter
            self._adapter.step(self._t, z, accept_prob, z_grads)
        # incremental running mean: m += (x - m) / n
        self._mean_accept_prob += (accept_prob.item() - self._mean_accept_prob) / n

        # shallow copy, presumably so callers cannot mutate the cached state dict
        return z.copy()