in src/beanmachine/ppl/legacy/inference/proposer/single_site_no_u_turn_sampler_proposer.py [0:0]
def propose(self, node: RVIdentifier, world: World) -> Tuple[Tensor, Tensor, Dict]:
    """
    Proposes a new value for the node using the No-U-Turn Sampler.

    Implements the doubling scheme of NUTS (Hoffman & Gelman, 2014,
    Algorithm 3): the trajectory is repeatedly doubled in a random
    direction until either a U-turn is detected or ``self.max_depth``
    doublings have occurred, and a candidate is drawn uniformly from the
    valid states via progressive (Bernoulli) sampling.

    :param node: the node for which we'll need to propose a new value for.
    :param world: the world in which we'll propose a new value for node.
    :returns: a new proposed value for the node and the difference in kinetic
    energy between the start and the end value
    """
    node_var = world.get_node_in_world_raise_error(node, False)
    if node_var.value is None:
        raise ValueError(f"{node} has no value")
    theta = node_var.transformed_value
    if not self.initialized:
        # One-time step-size heuristic; mu seeds the dual-averaging
        # adaptation (log of 10x the reasonable step size, per the paper).
        self._find_reasonable_step_size(node, world)
        self.mu = math.log(10 * self.step_size)
        self.best_step_size = 1.0
        self.initialized = True
    r = self._initialize_momentum(theta)
    # Log of the slice variable u ~ Uniform(0, 1); kept in log space for
    # numerical stability when compared against log-densities downstream.
    u = (
        dist.Uniform(tensor(0.0, dtype=theta.dtype, device=theta.device), 1.0)
        .sample()
        .log()
    )
    # Trajectory endpoints: (theta_n, r_n) is the backward end, (theta_p, r_p)
    # the forward end; both start at the current state.
    theta_n = theta
    theta_p = theta
    r_n = r
    r_p = r
    theta_propose = theta
    n = tensor(1.0, dtype=theta.dtype)  # count of valid states seen so far
    s = tensor(1.0, dtype=theta.dtype)  # continue-flag (1 = keep doubling)
    for j in range(self.max_depth):
        # v in {-1, +1}: direction in which to double the trajectory.
        v = (
            dist.Bernoulli(
                tensor(0.5, dtype=theta.dtype, device=theta.device)
            ).sample()
        ) * 2 - 1
        if v < 0:
            # Extend the backward end of the trajectory.
            build_tree_output = self._build_tree(
                node, world, (theta_n, r_n, u, v, j, theta, r)
            )
            theta_n, r_n, _, _, theta1, n1, s1, a, na = build_tree_output
        else:
            # Extend the forward end of the trajectory.
            build_tree_output = self._build_tree(
                node, world, (theta_p, r_p, u, v, j, theta, r)
            )
            _, _, theta_p, r_p, theta1, n1, s1, a, na = build_tree_output
        if torch.eq(s1, tensor(1.0, dtype=theta.dtype)):
            # Progressive sampling: accept the subtree's candidate with
            # probability min(1, n1 / n).
            change_val = dist.Bernoulli(
                torch.min(
                    tensor(1.0, dtype=theta.dtype, device=theta.device), n1 / n
                )
            ).sample()
            if change_val:
                theta_propose = theta1
        n = n + n1
        if torch.ne(s1, tensor(1.0, dtype=s1.dtype, device=s1.device)):
            # The subtree itself signalled a stop (divergence/U-turn inside).
            s = s1
        else:
            # Global U-turn check across the full trajectory span.
            if self.use_dense_mass_matrix:
                # With a dense mass matrix, whiten the endpoint difference
                # by L^-1 before testing alignment with the momenta.
                p_vector = torch.reshape(theta_p, (-1,))
                if self.l_inv is None:
                    self.l_inv = torch.eye(
                        len(p_vector), dtype=theta.dtype, device=theta.device
                    )
                transformed_p = torch.reshape(
                    torch.matmul(self.l_inv, p_vector), theta_p.shape
                )
                # BUG FIX: was torch.reshape(theta_p, (-1,)), which made
                # transformed_n == transformed_p, so the U-turn criterion
                # below was always satisfied and never stopped the doubling.
                n_vector = torch.reshape(theta_n, (-1,))
                transformed_n = torch.reshape(
                    torch.matmul(self.l_inv, n_vector), theta_n.shape
                )
            else:
                transformed_p = theta_p
                transformed_n = theta_n
            # Continue only while both momenta still point "outward" along
            # the (whitened) displacement between the two endpoints.
            turn_n = ((transformed_p - transformed_n) * r_n).sum() >= 0
            turn_p = ((transformed_p - transformed_n) * r_p).sum() >= 0
            s = turn_n * turn_p
        if torch.ne(s, tensor(1.0, dtype=s.dtype, device=s.device)):
            break
    # need this for adaptive step (sometimes na is 0)
    self.ratio = a.item() / max(na.item(), 1.0)
    q = node_var.inverse_transform_value(theta_propose)
    (children_log_update, _, node_log_update, _) = world.propose_change(
        node, q, allow_graph_update=False
    )
    world.reset_diff()
    # cancel out children and node log update because not needed for NUTS
    return q.detach(), -children_log_update - node_log_update, {}