def _proposer_func_for_node()

in src/lic/ppl/experimental/inference_compilation/ic_infer.py [0:0]
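
Builds and returns a proposer closure for a single random variable `node`. The closure embeds the current observations and the node's Markov blanket with the inference-compilation networks, feeds the concatenated embedding through the node's proposal parameter network, and constructs a proposal distribution from the resulting parameter vector.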


    def _proposer_func_for_node(self, node: RVIdentifier):
        # The constructor maps a parameter vector (produced by the proposal
        # parameter network below) to a concrete proposal distribution.
        _, proposal_dist_constructor = self._proposal_distribution_for_node(node)

        def _proposer_func(
            world: World,
            markov_blanket: Iterable[RVIdentifier],
            observations: Dict[RVIdentifier, Tensor],
        ) -> dist.Distribution:
            obs_embedding = torch.zeros(self._OBS_EMBEDDING_DIM)
            # Sort observations by node name so the ordering is deterministic.
            obs_nodes = [
                value
                for _, value in sorted(observations.items(), key=lambda kv: str(kv[0]))
            ]
            if len(obs_nodes):
                obs_embedding_net = self._obs_embedding_net
                if obs_embedding_net is None:
                    raise Exception("No observation embedding network found!")

                # Flatten all observed values into a single vector and embed it.
                obs_vec = torch.stack(obs_nodes, dim=0).flatten()
                # pyre-fixme
                obs_embedding = obs_embedding_net.forward(obs_vec)

            node_embedding_nets = self._node_embedding_nets
            if node_embedding_nets is None:
                raise Exception("No node embedding networks found!")

            mb_embedding = torch.zeros(self._MB_EMBEDDING_DIM)
            # Embed each Markov blanket node's current value with its
            # node-specific embedding network, in deterministic (sorted) order.
            mb_nodes = [
                node_embedding_nets(mb_node).forward(
                    utils.ensure_1d(
                        world.get_node_in_world_raise_error(mb_node).value
                    )
                )
                for mb_node in sorted(markov_blanket, key=str)
            ]
            if len(mb_nodes) and self._MB_EMBEDDING_DIM > 0:
                mb_embedding_nets = self._mb_embedding_nets
                if mb_embedding_nets is None:
                    raise Exception("No Markov blanket embedding networks found!")

                # NOTE: currently adds a batch axis (at index 1) here; this may
                # need to change when training is batched (see
                # torch.nn.utils.rnn.PackedSequence)
                mb_vec = torch.stack(mb_nodes, dim=0).unsqueeze(1)
                # TODO: try pooling rather than just slicing out the last
                # hidden state. `[0]` selects the RNN's output sequence and
                # `[-1, :, :]` takes its final time step.
                mb_embedding = utils.ensure_1d(
                    mb_embedding_nets(node).forward(mb_vec)[0][-1, :, :].squeeze()
                )
            node_proposal_param_nets = self._node_proposal_param_nets
            if node_proposal_param_nets is None:
                raise Exception("No node proposal parameter networks found!")
            # Map the concatenated (Markov blanket, observation) embedding to
            # the parameters of this node's proposal distribution.
            param_vec = node_proposal_param_nets(node).forward(
                torch.cat((mb_embedding, obs_embedding))
            )
            return proposal_dist_constructor(param_vec)

        return _proposer_func
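
The `[0][-1, :, :]` indexing assumes the Markov blanket embedding network is a recurrent module whose `forward` returns `(output, hidden)`: `[0]` selects the output sequence and `[-1, :, :]` its final time step. A minimal, self-contained sketch of that slicing (shapes are illustrative, not taken from the source):

    import torch
    import torch.nn as nn

    # Hypothetical shapes: 3 Markov blanket embeddings of dim 4, batch size 1.
    lstm = nn.LSTM(input_size=4, hidden_size=8)
    mb_vec = torch.randn(3, 1, 4)   # (seq_len, batch=1, features)
    out, _ = lstm(mb_vec)           # out: (seq_len, batch, hidden)
    last = out[-1, :, :].squeeze()  # final time step -> shape (8,)

The `proposal_dist_constructor` obtained from `_proposal_distribution_for_node` is what turns the parameter network's output into an actual distribution. A minimal sketch of such a constructor for a real-valued node (`make_normal_proposer` is a hypothetical stand-in; the real constructors depend on the node's support):

    import torch
    import torch.distributions as dist

    def make_normal_proposer(param_vec: torch.Tensor) -> dist.Distribution:
        # Hypothetical: the first half of param_vec parameterizes the mean;
        # the second half is passed through softplus for a positive scale.
        mean, pre_scale = param_vec.chunk(2)
        scale = torch.nn.functional.softplus(pre_scale)
        return dist.Normal(mean, scale)

    # With a length-2 parameter vector this yields a univariate proposal:
    proposal = make_normal_proposer(torch.tensor([0.5, -1.0]))
    sample = proposal.sample()
    log_q = proposal.log_prob(sample)

Because `_proposer_func` is re-evaluated against the current `world` on every call, the proposal adapts to both the observed data and the current values of the node's Markov blanket.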