in src/lic/ppl/experimental/inference_compilation/ic_infer.py [0:0]
def _build_node_embedding_network(self, node: RVIdentifier) -> nn.Module:
    """
    Build the embedding network for a single random-variable node.

    The returned module maps a node value ``x`` to the concatenation of a
    fixed, randomly drawn unit-norm node-ID vector and a learned linear
    embedding of ``x.float()``.

    :param node: random variable whose current value in ``self.world_`` fixes
        the input dimension of the linear embedding layer.
    :returns: an ``nn.Module`` whose ``forward(x)`` returns
        ``torch.cat((node_id_embedding, Linear(x.float())))``.
    """
    node_var = self.world_.get_node_in_world_raise_error(node)
    node_vec = utils.ensure_1d(node_var.value)
    # NOTE: assumes that node does not change shape across worlds
    node_embedding_net = nn.Sequential(
        nn.Linear(
            in_features=node_vec.shape[0], out_features=self._NODE_EMBEDDING_DIM
        )
    )
    # The node's identifier is a random vector scaled to unit L2 norm
    # (not a one-hot encoding).
    node_id_embedding = torch.randn(self._NODE_ID_EMBEDDING_DIM)
    node_id_embedding /= node_id_embedding.norm(p=2)

    class NodeEmbedding(nn.Module):
        """
        Node embedding network which concatenates a fixed random unit-norm
        node-ID vector with a learned embedding of the node value.
        """

        def __init__(self, node_id_embedding, embedding_net):
            super().__init__()
            # Registered as a buffer rather than a plain attribute so it
            # follows the module through .to()/.cuda() (avoiding a device
            # mismatch in torch.cat) and is included in state_dict. It is
            # not a trainable parameter. Attribute access is unchanged.
            self.register_buffer("node_id_embedding", node_id_embedding)
            self.embedding_net = embedding_net

        def forward(self, x):
            # Invoke the submodule itself (not .forward) so that any
            # registered forward hooks run.
            return torch.cat(
                (self.node_id_embedding, self.embedding_net(x.float()))
            )

    return NodeEmbedding(node_id_embedding, node_embedding_net)