in source/sagemaker/sagemaker_graph_entity_resolution/dgl_entity_resolution/model.py [0:0]
def forward(self, g, user_features, website_features):
    """Run the heterogeneous message-passing stack and return user outputs.

    Args:
        g: Per-layer graphs, indexed by layer position. ``g[0]`` supplies the
            node counts and node IDs used to look up learned embeddings, and
            ``g[i]`` is passed to ``self.layers[i]``. (Presumably DGL graphs
            or message-flow blocks -- TODO confirm against the caller.)
        user_features: Input representation for the ``'user'`` node type.
        website_features: Input representation for the ``'website'`` node
            type.

    Returns:
        The final layer's output for the ``'user'`` node type (the original
        code describes it as the user logits).
    """
    # Use the input features as-is. Wrapping them in nn.Parameter here would
    # create fresh, unregistered leaf tensors on every call (never seen by the
    # optimizer, with wasted .grad buffers) and would cut any autograd history
    # the inputs carry.
    h_dict = {'user': user_features, 'website': website_features}

    # Node types that have a learned embedding table and at least one node in
    # the first graph get the embedding rows selected by their node IDs.
    # NOTE(review): if self.embed ever contains 'user' or 'website', this
    # overwrites the passed-in features above -- confirm it does not.
    for ntype in self.embed:
        if g[0].number_of_nodes(ntype) > 0:
            h_dict[ntype] = self.embed[ntype][g[0].nodes(ntype).long(), :]

    # Pass through all layers, applying leaky ReLU between consecutive layers
    # (not before the first one).
    for i, layer in enumerate(self.layers):
        if i != 0:
            h_dict = {k: F.leaky_relu(h) for k, h in h_dict.items()}
        h_dict = layer(g[i], h_dict)

    return h_dict['user']