in src/lic/ppl/experimental/inference_compilation/ic_infer.py [0:0]
def _proposer_func_for_node(self, node: RVIdentifier):
    """Build the amortized proposer function for ``node``.

    The returned closure:

    1. Embeds the observed values with ``self._obs_embedding_net``.
    2. Embeds the values of ``node``'s Markov blanket with the per-node
       embedding nets, then with ``node``'s Markov-blanket embedding net.
    3. Passes the concatenation ``(mb_embedding, obs_embedding)`` to
       ``node``'s proposal-parameter network.
    4. Hands the resulting parameters to the proposal-distribution
       constructor for ``node``.

    :param node: random variable whose proposer is being built.
    :returns: a callable ``(world, markov_blanket, observations)`` returning
        the proposal distribution for ``node``.
    """
    # Only the constructor is used here; the first element returned by
    # ``_proposal_distribution_for_node`` is discarded.
    _, proposal_dist_constructor = self._proposal_distribution_for_node(node)

    def _proposer_func(
        world: World,
        markov_blanket: Iterable[RVIdentifier],
        observations: Dict[RVIdentifier, Tensor],
    ) -> dist.Distribution:
        """Return the proposal distribution for ``node``.

        :param world: world from which the Markov blanket values are read.
        :param markov_blanket: random variables in ``node``'s Markov blanket.
        :param observations: observed random variables mapped to their values.
        :raises Exception: if a required network attribute
            (``_obs_embedding_net``, ``_node_embedding_nets``,
            ``_mb_embedding_nets`` or ``_node_proposal_param_nets``) is None
            when it is needed.
        """
        # Observation embedding; remains all-zeros when nothing is observed.
        obs_embedding = torch.zeros(self._OBS_EMBEDDING_DIM)
        # Sort by the string form of each key so the stacked observation
        # vector has a deterministic ordering across calls.
        obs_nodes = [
            value
            for _, value in sorted(
                observations.items(), key=lambda item: str(item[0])
            )
        ]
        # Previously disabled by a leftover ``if False:``, which silently made
        # proposals ignore the observations.
        if obs_nodes:
            obs_embedding_net = self._obs_embedding_net
            if obs_embedding_net is None:
                raise Exception("No observation embedding network found!")
            obs_vec = torch.stack(obs_nodes, dim=0).flatten()
            # pyre-fixme
            obs_embedding = obs_embedding_net.forward(obs_vec)

        node_embedding_nets = self._node_embedding_nets
        if node_embedding_nets is None:
            raise Exception("No node embedding networks found!")
        # Markov blanket embedding; remains all-zeros when the blanket is
        # empty or the embedding dimension is 0.
        mb_embedding = torch.zeros(self._MB_EMBEDDING_DIM)
        # Sorted by str for the same deterministic-ordering reason as above.
        mb_nodes = [
            node_embedding_nets(mb_node).forward(
                utils.ensure_1d(
                    world.get_node_in_world_raise_error(mb_node).value
                )
            )
            for mb_node in sorted(markov_blanket, key=str)
        ]
        if mb_nodes and self._MB_EMBEDDING_DIM > 0:
            mb_embedding_nets = self._mb_embedding_nets
            if mb_embedding_nets is None:
                raise Exception("No Markov blanket embedding networks found!")
            # NOTE: currently adds batch axis (at index 1) here, may need
            # to change when we batch training (see
            # torch.nn.utils.rnn.PackedSequence)
            mb_vec = torch.stack(mb_nodes, dim=0).unsqueeze(1)
            # NOTE(review): ``[0]`` presumably selects the output sequence of
            # a recurrent net and ``[-1]`` its last step — confirm against the
            # network definition.
            # TODO: try pooling rather than just slicing out last hidden
            mb_embedding = utils.ensure_1d(
                mb_embedding_nets(node).forward(mb_vec)[0][-1, :, :].squeeze()
            )

        node_proposal_param_nets = self._node_proposal_param_nets
        if node_proposal_param_nets is None:
            raise Exception("No node proposal parameter networks found!")
        param_vec = node_proposal_param_nets(node).forward(
            torch.cat((mb_embedding, obs_embedding))
        )
        return proposal_dist_constructor(param_vec)

    return _proposer_func