in src/beanmachine/ppl/legacy/world/world.py [0:0]
def update_graph(self, node: RVIdentifier) -> Tensor:
    """
    Updates the parents and children of the node based on the stack, and
    returns the node's value in this world.

    If ``node`` is already present in the world, its existing value is
    returned. When ``maintain_graph_`` is set, the current caller (top of
    ``stack_``) is first recorded as one of its children.

    Otherwise a placeholder ``Variable`` is added to the world. Then
    ``node.function`` is evaluated with ``node`` pushed on ``stack_``, and the
    returned distribution is stored. A value is picked in this order:

    - the observed value, if ``node`` is in ``observations_``;
    - a sample from ``vi_dicts(node)`` (mean-field VI);
    - a sample from the node's guide (guide-based VI);
    - otherwise ``None``, leaving the choice to ``Variable.update_fields``.

    :param node: the node which was called from StatisticalModel.random_variable()
    :returns: the value of ``node`` in this world
    :raises TypeError: if ``node.function`` does not return a ``Distribution``
    :raises AssertionError: under guide-based VI, if ``node`` has no guide
        (only while assertions are enabled)
    """
    # Exactly one of graph maintenance / functional caching must be enabled.
    assert self.maintain_graph_ != self.cache_functionals_
    # The random variable currently being evaluated, i.e. the one that
    # invoked ``node``; None when ``node`` is invoked at the top level.
    caller = self.stack_[-1] if self.stack_ else None
    if caller is not None:
        # We are updating the caller's parents, so we look it up with
        # get_node_in_world_raise_error. The caller does not need to be added
        # to the latest diff here: being on the stack means it was already
        # added there.
        self.get_node_in_world_raise_error(caller).parent.add(node)
    # Unlike the caller above, the node's own entry is written into the
    # world explicitly below via add_node_to_world.
    node_var = self.get_node_in_world(node, False)
    if node_var is not None:
        if (
            self.maintain_graph_
            and caller is not None
            and caller not in node_var.children
        ):
            # Record the new child on a copy and store that copy, rather
            # than mutating the looked-up variable in place.
            var_copy = node_var.copy()
            var_copy.children.add(caller)
            self.add_node_to_world(node, var_copy)
        return node_var.value
    node_var = Variable(
        # pyre-fixme
        distribution=None,
        # pyre-fixme[6]: Expected `Tensor` for 2nd param but got `None`.
        value=None,
        # pyre-fixme[6]: Expected `Tensor` for 3rd param but got `None`.
        log_prob=None,
        children=set() if caller is None else {caller},
        # pyre-fixme[6]: Expected `Transform` for 5th param but got `None`.
        transform=None,
        # pyre-fixme[6]: Expected `Tensor` for 6th param but got `None`.
        transformed_value=None,
        # pyre-fixme[6]: Expected `Tensor` for 7th param but got `None`.
        jacobian=None,
    )
    self.add_node_to_world(node, node_var)
    # While node.function runs, any random variable it calls sees ``node`` on
    # top of the stack and records it as a child.
    self.stack_.append(node)
    try:
        with self:
            d = node.function(*node.arguments)
            if not isinstance(d, Distribution):
                raise TypeError(
                    "A random_variable is required to return a distribution."
                )
        node_var.distribution = d
    finally:
        # Pop even if evaluation raised. Otherwise the failed node would stay
        # on the stack and later calls would wrongly treat it as their caller.
        self.stack_.pop()
    obs_value = self.observations_.get(node)
    # resample latents from q
    value = None
    vi_dicts = self.vi_dicts
    model_to_guide_ids = self.model_to_guide_ids_
    if obs_value is None:
        # TODO: messy, consider strategy pattern
        if vi_dicts is not None:
            # mean-field VI: draw one reparameterized sample from the
            # approximation for this node.
            variational_approx = vi_dicts(node)
            value = variational_approx.rsample((1,)).squeeze()
        elif (
            isinstance(model_to_guide_ids, dict)
            and node not in model_to_guide_ids.values()  # is not a guide RV
        ):
            # guide-based VI on non-guide nodes only
            assert (
                node in model_to_guide_ids
            ), f"Could not find a guide for {node}. VariationalInference requires every latent variable in the model to have a corresponding guide."
            guide_node = model_to_guide_ids[node]
            guide_var = self.get_node_in_world(guide_node)
            if guide_var is None:
                # initialize guide node if missing
                self.call(guide_node)
                guide_var = self.get_node_in_world_raise_error(guide_node)
            # Prefer a reparameterized sample; fall back to a plain sample
            # for distributions that do not implement rsample.
            try:
                value = guide_var.distribution.rsample(torch.Size((1,)))
            except NotImplementedError:
                value = guide_var.distribution.sample(torch.Size((1,)))
    node_var.update_fields(
        value,
        obs_value,
        self.get_transforms_for_node(node),
        self.get_proposer_for_node(node),
        self.initialize_from_prior_,
    )
    if self.maintain_graph_:
        self.update_diff_log_prob(node)
    return node_var.value