def _build_node_embedding_network()

in src/lic/ppl/experimental/inference_compilation/ic_infer.py [0:0]


    def _build_node_embedding_network(self, node: RVIdentifier) -> nn.Module:
        """
        Build the embedding network for ``node``.

        The returned module maps a value of ``node`` to the concatenation of
        (a) a fixed, randomly drawn, L2-normalized node-ID vector of length
        ``self._NODE_ID_EMBEDDING_DIM`` and (b) a learned linear embedding of
        the value of length ``self._NODE_EMBEDDING_DIM``.

        :param node: random variable whose current value in ``self.world_``
            fixes the input width of the linear layer.
        :returns: an ``nn.Module`` whose output is the 1-D concatenation
            described above.
        """
        # Presumably raises if ``node`` is not present in the world (per name).
        node_var = self.world_.get_node_in_world_raise_error(node)
        node_vec = utils.ensure_1d(node_var.value)
        # NOTE: assumes that node does not change shape across worlds
        node_embedding_net = nn.Sequential(
            nn.Linear(
                in_features=node_vec.shape[0], out_features=self._NODE_EMBEDDING_DIM
            )
        )
        # Random unit-norm vector that distinguishes this node from others.
        node_id_embedding = torch.randn(self._NODE_ID_EMBEDDING_DIM)
        node_id_embedding /= node_id_embedding.norm(p=2)

        class NodeEmbedding(nn.Module):
            """
            Node embedding network which concatenates a fixed, random unit-norm
            node-ID vector with a learned linear embedding of the node value.
            """

            def __init__(self, node_id_embedding, embedding_net):
                super().__init__()
                # Registered as a buffer (not a plain attribute) so that
                # ``.to(device)`` / ``.cuda()`` move it alongside
                # ``embedding_net``; otherwise ``torch.cat`` in ``forward``
                # fails with a device mismatch. ``persistent=False`` keeps the
                # ``state_dict`` keys identical to the previous behavior.
                # NOTE(review): because it is not persisted, a network rebuilt
                # from a checkpoint gets a freshly drawn ID vector.
                self.register_buffer(
                    "node_id_embedding", node_id_embedding, persistent=False
                )
                self.embedding_net = embedding_net

            def forward(self, x):
                # Assumes ``x`` is 1-D with the same width as the value used to
                # size the Linear layer (cat along dim 0 with a 1-D ID vector
                # requires a 1-D embedding) — TODO confirm callers guarantee it.
                # Call the submodule rather than ``.forward`` so hooks run.
                return torch.cat(
                    (self.node_id_embedding, self.embedding_net(x.float()))
                )

        return NodeEmbedding(node_id_embedding, node_embedding_net)