in src/lic/ppl/inference/abstract_mh_infer.py [0:0]
def block_propose_change(self, block: Block) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Propose new values for a block of random variables.

    A block is a group of random-variable *families* (e.g. the family X,
    rather than a specific indexed variable X(1)) which are sequentially
    re-sampled in ``self.world_`` and whose proposed values are meant to be
    accepted or rejected all together by the caller. Only nodes belonging
    to families in the block AND currently in the (growing) markov blanket
    of the block's first node get re-sampled.

    :param block: the block to propose new values for.
    :returns: a tuple of (nodes_log_updates, children_log_updates,
        proposal_log_updates) accumulated over all values proposed for
        the block.
    """
    # The markov blanket starts as just the block's first node and grows as
    # nodes are re-sampled (re-sampling can change a node's dependencies).
    markov_blanket = set({block.first_node})
    # Map from a node's family (its wrapped function) to the concrete nodes
    # of that family seen so far; the block matches at the family level, so
    # this makes the per-family lookup below fast.
    markov_blanket_func = {}
    markov_blanket_func[get_wrapper(block.first_node.function)] = [block.first_node]
    # Running totals; kept as tensors so the += below stays in tensor land.
    pos_proposal_log_updates, neg_proposal_log_updates = tensor(0.0), tensor(0.0)
    children_log_updates, nodes_log_updates = tensor(0.0), tensor(0.0)
    # We will go through all family of random variable in the block. Note
    # that in block we have family of X and not the specific random variable
    # X(1)
    for node_func in block.block:
        # We then look up which of the random variable in the family are in
        # the markov blanket
        if node_func not in markov_blanket_func:
            continue
        # We will go through all random variables that are both in the
        # markov blanket and block. Iterate over a copy because the inner
        # loop below appends new nodes to this same list.
        for node in markov_blanket_func[node_func].copy():
            # Skip nodes already scheduled for deletion from the world.
            if self.world_.is_marked_node_for_delete(node):
                continue
            # We look up the node's current markov blanket before re-sampling
            # (observed nodes are excluded — they are never re-sampled).
            old_node_markov_blanket = (
                self.world_.get_markov_blanket(node) - self.observations_.keys()
            )
            proposer = self.find_best_single_site_proposer(node)
            LOGGER_PROPOSER.log(
                LogLevel.DEBUG_PROPOSER.value,
                "=" * 30
                + "\n"
                + "Proposer info for node: {n}\n".format(n=node)
                + "- Type: {pt}\n".format(pt=str(type(proposer))),
            )
            # We use the best single site proposer to propose a new value.
            # negative_proposal_log_update is the proposer's contribution
            # evaluated before the world is updated; the positive part is
            # added via post_process() after the change (see below).
            (
                proposed_value,
                negative_proposal_log_update,
                auxiliary_variables,
            ) = proposer.propose(node, self.world_)
            neg_proposal_log_updates += negative_proposal_log_update
            LOGGER_INFERENCE.log(
                LogLevel.DEBUG_UPDATES.value,
                "Node: {n}\n".format(n=node)
                + "- Node value: {nv}\n".format(
                    # pyre-fixme
                    nv=self.world_.get_node_in_world(node, False, False).value
                )
                + "- Proposed value: {pv}\n".format(pv=proposed_value),
            )
            # We update the world (through a new diff in the diff stack), so
            # the caller can accept or roll back all block changes together.
            children_log_update, _, node_log_update, _ = self.world_.propose_change(
                node, proposed_value, start_new_diff=True
            )
            children_log_updates += children_log_update
            nodes_log_updates += node_log_update
            pos_proposal_log_updates += proposer.post_process(
                node, self.world_, auxiliary_variables
            )
            # We look up the updated markov blanket of the re-sampled node.
            new_node_markov_blanket = (
                self.world_.get_markov_blanket(node) - self.observations_.keys()
            )
            # Union of blankets before and after the change: nodes from
            # either may still need re-sampling within this block.
            all_node_markov_blanket = (
                old_node_markov_blanket | new_node_markov_blanket
            )
            # new_nodes_to_be_added is all the new nodes to be added to
            # entire markov blanket.
            new_nodes_to_be_added = all_node_markov_blanket - markov_blanket
            for new_node in new_nodes_to_be_added:
                # NOTE(review): get_markov_blanket apparently can yield None
                # entries — confirm; they are skipped here.
                if new_node is None:
                    continue
                # We create a dictionary from node family to the node itself
                # as the match with block happens at the family level and
                # this makes the lookup much faster.
                if get_wrapper(new_node.function) not in markov_blanket_func:
                    markov_blanket_func[get_wrapper(new_node.function)] = []
                markov_blanket_func[get_wrapper(new_node.function)].append(new_node)
            markov_blanket |= new_nodes_to_be_added
    # Total proposal log update is the sum of the pre-change (negative) and
    # post-change (positive) proposer contributions.
    proposal_log_updates = pos_proposal_log_updates + neg_proposal_log_updates
    return nodes_log_updates, children_log_updates, proposal_log_updates