in src/lic/ppl/world/world.py [0:0]
def update_graph(self, node: RVIdentifier) -> Tensor:
    """
    Updates the parents and children of the node based on the stack.

    If ``node`` already has a variable in this world, that variable is reused.
    When ``self.maintain_graph_`` is set, the node on top of the stack (the
    caller) is added to its children. The existing value is then returned.

    Otherwise a fresh ``Variable`` is created and stored. Its distribution is
    obtained by calling ``node.function(*node.arguments)`` with ``node`` pushed
    on the stack. ``update_fields`` is then called with:

    - a value resampled from ``self.vi_dicts[node]`` (only when
      ``self.vi_dicts`` is set and the node is unobserved),
    - the observed value, if any,
    - the node's transforms and proposer,
    - ``self.initialize_from_prior_``.

    :param node: the node which was called from StatisticalModel.random_variable()
    :returns: the value of ``node``'s variable in this world
    """
    # Exactly one of graph maintenance / functional caching is expected to be on.
    assert self.maintain_graph_ != self.cache_functionals_
    if self.stack_:
        # The node on top of the stack is currently being evaluated and has
        # just called ``node``, so ``node`` is one of its parents. The caller's
        # variable is fetched with get_node_in_world_raise_error. It does not
        # need to be re-added to the latest diff: being on the stack means it
        # was already added when it was created.
        self.get_node_in_world_raise_error(self.stack_[-1]).parent.add(node)
    # ``node``'s own entry is written explicitly below via add_node_to_world
    # (presumably into the latest diff — confirm against add_node_to_world).
    node_var = self.get_node_in_world(node, False)
    if node_var is not None:
        if (
            self.maintain_graph_
            and self.stack_
            and self.stack_[-1] not in node_var.children
        ):
            # Record the caller as a child on a copy of the variable and store
            # the copy back into the world.
            var_copy = node_var.copy()
            var_copy.children.add(self.stack_[-1])
            self.add_node_to_world(node, var_copy)
        return node_var.value

    node_var = Variable(
        # pyre-fixme
        distribution=None,
        value=None,
        log_prob=None,
        parent=set(),
        children={self.stack_[-1]} if self.stack_ else set(),
        proposal_distribution=None,
        is_discrete=None,
        transforms=None,
        transformed_value=None,
        jacobian=None,
    )
    self.add_node_to_world(node, node_var)

    # Evaluate the node's function with ``node`` on top of the stack. Any
    # random variable it calls re-enters update_graph and is recorded as a
    # parent of ``node``. ``with self`` presumably activates this world for
    # the duration of the call (confirm against World.__enter__). The pop is
    # in ``finally`` so an exception raised by the model does not leave a
    # stale entry on the stack.
    self.stack_.append(node)
    try:
        with self:
            node_var.distribution = node.function(*node.arguments)
    finally:
        self.stack_.pop()

    obs_value = self.observations_[node] if node in self.observations_ else None
    value = None
    # Compare against None explicitly. ``not obs_value`` raises for
    # multi-element tensors and treats an observed value of 0 as "unobserved".
    if self.vi_dicts and obs_value is None:
        # resample latents from q
        value = self.vi_dicts[node].rsample((1,))
    node_var.update_fields(
        value,
        obs_value,
        self.get_transforms_for_node(node),
        self.get_proposer_for_node(node),
        self.initialize_from_prior_,
    )
    if self.maintain_graph_:
        self.update_diff_log_prob(node)
    return node_var.value