def update_graph()

in src/lic/ppl/world/world.py [0:0]


    def update_graph(self, node: RVIdentifier) -> Tensor:
        """
        Updates the parents and children of the node based on the stack.

        If ``node`` is already present in the world, its current value is
        returned; when ``maintain_graph_`` is set, the node on top of the
        stack (the caller) is first recorded as one of its children.
        Otherwise a new Variable is created and registered, its distribution
        is obtained by calling ``node.function(*node.arguments)`` with
        ``node`` pushed on the stack, its fields are filled in via
        ``update_fields`` and, when ``maintain_graph_`` is set,
        ``update_diff_log_prob(node)`` is called.

        :param node: the node which was called from StatisticalModel.random_variable()
        :returns: the ``value`` of the node's Variable in the world
        :raises AssertionError: if ``maintain_graph_`` and
            ``cache_functionals_`` are both set or both unset
        """
        assert self.maintain_graph_ != self.cache_functionals_
        if self.stack_:
            # The node on top of the stack is the one whose function is
            # currently being evaluated, so it is a parent of ``node``. We are
            # updating that variable, hence get_node_in_world_raise_error; it
            # does not need to be added to the latest diff again because it
            # was already added there when it was pushed on the stack.
            self.get_node_in_world_raise_error(self.stack_[-1]).parent.add(node)

        # Existing node: record the caller as a child. The variable is copied
        # and re-registered through add_node_to_world rather than mutated in
        # place (presumably so the change lands in the latest diff instead of
        # altering the stored variable -- TODO confirm). The meaning of the
        # ``False`` flag to get_node_in_world is not visible here.
        node_var = self.get_node_in_world(node, False)
        if node_var is not None:
            if (
                self.maintain_graph_
                and self.stack_
                and self.stack_[-1] not in node_var.children
            ):
                var_copy = node_var.copy()
                var_copy.children.add(self.stack_[-1])
                self.add_node_to_world(node, var_copy)
            return node_var.value

        node_var = Variable(
            # pyre-fixme
            distribution=None,
            value=None,
            log_prob=None,
            parent=set(),
            children={self.stack_[-1]} if self.stack_ else set(),
            proposal_distribution=None,
            is_discrete=None,
            transforms=None,
            transformed_value=None,
            jacobian=None,
        )

        self.add_node_to_world(node, node_var)
        # Push the node so that random variables reached while evaluating its
        # function (presumably via recursive update_graph calls) see it as the
        # top of the stack and link to it as parent/child.
        self.stack_.append(node)
        try:
            with self:
                node_var.distribution = node.function(*node.arguments)
        finally:
            # Pop even if the model function raises; otherwise the stale entry
            # would be treated as the parent by every later call.
            self.stack_.pop()

        obs_value = self.observations_[node] if node in self.observations_ else None

        value = None
        # Compare against None explicitly: truth-testing an observed value
        # (presumably a tensor) is ambiguous for multi-element tensors and
        # treats an observed scalar 0 as "unobserved", which would wrongly
        # resample an observed node from q.
        if self.vi_dicts and obs_value is None:
            # resample latents from q
            value = self.vi_dicts[node].rsample((1,))

        node_var.update_fields(
            value,
            obs_value,
            self.get_transforms_for_node(node),
            self.get_proposer_for_node(node),
            self.initialize_from_prior_,
        )

        if self.maintain_graph_:
            self.update_diff_log_prob(node)

        return node_var.value