def propose()

in src/beanmachine/ppl/legacy/inference/proposer/single_site_no_u_turn_sampler_proposer.py [0:0]


    def propose(self, node: RVIdentifier, world: World) -> Tuple[Tensor, Tensor, Dict]:
        """
        Proposes a new value for the node using the No-U-Turn Sampler
        (NUTS; Hoffman & Gelman 2014, Algorithm 3): doubles a simulated
        Hamiltonian trajectory until a U-turn is detected, then samples a
        point from the accepted sub-trajectory via multinomial selection.

        :param node: the node for which we'll need to propose a new value for.
        :param world: the world in which we'll propose a new value for node.
        :returns: a new proposed value for the node and the difference in kinetic
        energy between the start and the end value
        :raises ValueError: if the node has no value in the world.
        """
        node_var = world.get_node_in_world_raise_error(node, False)
        if node_var.value is None:
            raise ValueError(f"{node} has no value")
        # Work in the unconstrained (transformed) space.
        theta = node_var.transformed_value

        # One-time setup: tune an initial step size and seed the dual-averaging
        # state (mu, best_step_size) used by the adaptive-step machinery.
        if not self.initialized:
            self._find_reasonable_step_size(node, world)
            self.mu = math.log(10 * self.step_size)
            self.best_step_size = 1.0
            self.initialized = True
        # Fresh momentum draw for this transition.
        r = self._initialize_momentum(theta)

        # Log of the slice variable u ~ Uniform(0, 1); _build_tree compares
        # against it in log space.
        u = (
            dist.Uniform(tensor(0.0, dtype=theta.dtype, device=theta.device), 1.0)
            .sample()
            .log()
        )

        # Trajectory endpoints: (theta_n, r_n) is the backward end,
        # (theta_p, r_p) the forward end. n counts accepted leapfrog states,
        # s == 1 while the trajectory is still valid (no U-turn / divergence).
        theta_n = theta
        theta_p = theta
        r_n = r
        r_p = r
        theta_propose = theta
        n = tensor(1.0, dtype=theta.dtype)
        s = tensor(1.0, dtype=theta.dtype)

        for j in range(self.max_depth):
            # v in {-1, +1}: direction in which to double the trajectory.
            v = (
                dist.Bernoulli(
                    tensor(0.5, dtype=theta.dtype, device=theta.device)
                ).sample()
            ) * 2 - 1
            if v < 0:
                # Extend backward; only the backward endpoint is updated.
                build_tree_output = self._build_tree(
                    node, world, (theta_n, r_n, u, v, j, theta, r)
                )
                theta_n, r_n, _, _, theta1, n1, s1, a, na = build_tree_output
            else:
                # Extend forward; only the forward endpoint is updated.
                build_tree_output = self._build_tree(
                    node, world, (theta_p, r_p, u, v, j, theta, r)
                )
                _, _, theta_p, r_p, theta1, n1, s1, a, na = build_tree_output

            # If the new subtree is valid, accept its candidate with
            # probability min(1, n1 / n) (progressive multinomial sampling).
            if torch.eq(s1, tensor(1.0, dtype=theta.dtype)):
                change_val = dist.Bernoulli(
                    torch.min(
                        tensor(1.0, dtype=theta.dtype, device=theta.device), n1 / n
                    )
                ).sample()
                if change_val:
                    theta_propose = theta1
            n = n + n1

            if torch.ne(s1, tensor(1.0, dtype=s1.dtype, device=s1.device)):
                # Subtree terminated (U-turn or divergence inside it).
                s = s1
            else:
                if self.use_dense_mass_matrix:
                    # With a dense mass matrix, the U-turn check is done in the
                    # space whitened by L^{-1} (inverse Cholesky factor).
                    p_vector = torch.reshape(theta_p, (-1,))
                    if self.l_inv is None:
                        self.l_inv = torch.eye(
                            len(p_vector), dtype=theta.dtype, device=theta.device
                        )
                    transformed_p = torch.reshape(
                        torch.matmul(self.l_inv, p_vector), theta_p.shape
                    )
                    # BUG FIX: was torch.reshape(theta_p, (-1,)), which made
                    # transformed_n == transformed_p, so the U-turn dot products
                    # below were always zero and no U-turn was ever detected.
                    n_vector = torch.reshape(theta_n, (-1,))
                    transformed_n = torch.reshape(
                        torch.matmul(self.l_inv, n_vector), theta_n.shape
                    )
                else:
                    transformed_p = theta_p
                    transformed_n = theta_n

                # No-U-turn criterion: keep doubling only while the momenta at
                # both endpoints still point away from each other, i.e.
                # (theta+ - theta-) . r >= 0 at both ends.
                turn_n = ((transformed_p - transformed_n) * r_n).sum() >= 0
                turn_p = ((transformed_p - transformed_n) * r_p).sum() >= 0
                s = turn_n * turn_p

            if torch.ne(s, tensor(1.0, dtype=s.dtype, device=s.device)):
                break

        # need this for adaptive step (sometimes na is 0)
        self.ratio = a.item() / max(na.item(), 1.0)

        # Map the chosen point back to the constrained space and install it.
        q = node_var.inverse_transform_value(theta_propose)
        (children_log_update, _, node_log_update, _) = world.propose_change(
            node, q, allow_graph_update=False
        )
        world.reset_diff()

        # cancel out children and node log update because not needed for NUTS
        return q.detach(), -children_log_update - node_log_update, {}