def update_graph()

in src/beanmachine/ppl/legacy/world/world.py


    def update_graph(self, node: RVIdentifier) -> Tensor:
        """
        Updates the parents and children of the node based on the stack

        :param node: the node which was called from StatisticalModel.random_variable()
        """
        assert self.maintain_graph_ != self.cache_functionals_
        if len(self.stack_) > 0:
            # We are updating the caller's parent set, so we fetch the caller
            # with get_node_in_world_raise_error. There is no need to add the
            # variable to the latest diff: being on the stack means it has
            # already been added there.
            self.get_node_in_world_raise_error(self.stack_[-1]).parent.add(node)

        # We add the variable to the latest diff manually at lines 509 and 527.
        node_var = self.get_node_in_world(node, False)
        if node_var is not None:
            if (
                self.maintain_graph_
                and len(self.stack_) > 0
                and self.stack_[-1] not in node_var.children
            ):
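                # Copy-on-write: mutate a copy and re-add it to the world so
                # the new child edge lands in the latest diff rather than in
                # the stored variable.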
                var_copy = node_var.copy()
                var_copy.children.add(self.stack_[-1])
                self.add_node_to_world(node, var_copy)
            return node_var.value

        node_var = Variable(
            # pyre-fixme
            distribution=None,
            # pyre-fixme[6]: Expected `Tensor` for 2nd param but got `None`.
            value=None,
            # pyre-fixme[6]: Expected `Tensor` for 3rd param but got `None`.
            log_prob=None,
            children=set() if len(self.stack_) == 0 else {self.stack_[-1]},
            # pyre-fixme[6]: Expected `Transform` for 5th param but got `None`.
            transform=None,
            # pyre-fixme[6]: Expected `Tensor` for 6th param but got `None`.
            transformed_value=None,
            # pyre-fixme[6]: Expected `Tensor` for 7th param but got `None`.
            jacobian=None,
        )

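        # Register the placeholder and push the node onto the stack before
        # evaluating its function, so nested random_variable calls record
        # this node as their caller.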
        self.add_node_to_world(node, node_var)
        self.stack_.append(node)
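        # Entering the world's context routes the nested random_variable
        # calls made by node.function back through update_graph.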
        with self:
            d = node.function(*node.arguments)
            if not isinstance(d, Distribution):
                raise TypeError(
                    "A random_variable is required to return a distribution."
                )
            node_var.distribution = d
        self.stack_.pop()

        obs_value = self.observations_.get(node)

        # resample latents from q
        value = None
        vi_dicts = self.vi_dicts
        model_to_guide_ids = self.model_to_guide_ids_
        if obs_value is None:
            # TODO: messy, consider strategy pattern
            if vi_dicts is not None:
                # mean-field VI
                variational_approx = vi_dicts(node)
                value = variational_approx.rsample((1,)).squeeze()
            elif (
                isinstance(model_to_guide_ids, dict)
                and node not in model_to_guide_ids.values()  # not itself a guide RV
            ):
                # guide-based VI on non-guide nodes only
                assert (
                    node in model_to_guide_ids
                ), f"Could not find a guide for {node}. VariationalInference requires every latent variable in the model to have a corresponding guide."
                guide_node = model_to_guide_ids[node]
                guide_var = self.get_node_in_world(guide_node)
                if not guide_var:
                    # initialize guide node if missing
                    self.call(guide_node)
                guide_var = self.get_node_in_world_raise_error(guide_node)
                try:
                    value = guide_var.distribution.rsample(torch.Size((1,)))
                except NotImplementedError:
                    value = guide_var.distribution.sample(torch.Size((1,)))

        node_var.update_fields(
            value,
            obs_value,
            self.get_transforms_for_node(node),
            self.get_proposer_for_node(node),
            self.initialize_from_prior_,
        )

        if self.maintain_graph_:
            self.update_diff_log_prob(node)

        return node_var.value
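
For intuition, here is a minimal, self-contained sketch of the stack-based
dependency discovery that update_graph performs. ToyWorld and ToyVariable are
hypothetical stand-ins, not beanmachine's real API; the real method also
handles diffs, observations, transforms, and VI resampling.

    from dataclasses import dataclass, field


    @dataclass
    class ToyVariable:
        parent: set = field(default_factory=set)
        children: set = field(default_factory=set)
        value: float = 0.0


    class ToyWorld:
        def __init__(self):
            self.variables_ = {}
            self.stack_ = []

        def update_graph(self, node):
            # The caller, if any, is on top of the stack: record the edge.
            if self.stack_:
                self.variables_[self.stack_[-1]].parent.add(node)
            if node in self.variables_:
                if self.stack_:
                    self.variables_[node].children.add(self.stack_[-1])
                return self.variables_[node].value
            # Register a placeholder before evaluating the function, so
            # nested calls see this node as their caller.
            var = ToyVariable(children=set(self.stack_[-1:]))
            self.variables_[node] = var
            self.stack_.append(node)
            var.value = node()
            self.stack_.pop()
            return var.value


    world = ToyWorld()

    def mu():
        return 0.0

    def x():
        return world.update_graph(mu) + 1.0

    world.update_graph(x)  # evaluates x, discovering mu along the way
    assert mu in world.variables_[x].parent
    assert x in world.variables_[mu].children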