in source/sagemaker/sagemaker_graph_entity_resolution/dgl_entity_resolution/model.py [0:0]
def inference(self, g, user_features, web_features, batch_size, n_neighbors, device, num_workers=0):
    """Compute last-layer 'user' representations by layer-wise mini-batch inference.

    Each RGCN layer is evaluated over every node of ``g`` (in mini-batches of seed
    nodes) before the next layer starts; the per-layer outputs are kept on the CPU
    and become the inputs of the next layer.

    Args:
        g: DGL heterograph with 'user' and 'website' node types and possibly others.
        user_features: input features for 'user' nodes, indexed by global node ID.
        web_features: input features for 'website' nodes, indexed by global node ID.
        batch_size: number of seed nodes per mini-batch.
        n_neighbors: fan-out given to ``MultiLayerNeighborSampler`` (one hop per layer).
        device: device on which each mini-batch is evaluated.
        num_workers: number of worker processes for the node data loader.

    Returns:
        CPU tensor of shape ``(g.number_of_nodes('user'), self.n_hidden)`` holding the
        'user' outputs of the last layer.
    """
    # The sampler and loader do not depend on the layer, so build them once.
    sampler = dgl.dataloading.MultiLayerNeighborSampler([n_neighbors])
    dataloader = dgl.dataloading.NodeDataLoader(
        g,
        {ntype: torch.arange(g.number_of_nodes(ntype)) for ntype in g.ntypes},
        sampler,
        batch_size=batch_size,
        shuffle=False,  # every node is visited once per layer; order does not matter
        drop_last=False,
        num_workers=num_workers)
    other_ntypes = [ntype for ntype in g.ntypes if ntype not in ('user', 'website')]
    last_layer = len(self.rgcn.layers) - 1

    # Outputs of the previous layer, i.e. the inputs of the current one. They must
    # live in buffers separate from the ones the current layer writes: reading from
    # the freshly zeroed output tensors would feed zeros (or a mix of layer-k and
    # layer-(k-1) values) into layer k.
    prev_user = prev_website = prev_others = None

    # Pure inference: don't build an autograd graph spanning all mini-batches.
    with torch.no_grad():
        for layer_idx, layer in enumerate(self.rgcn.layers):
            # NOTE(review): every layer, including the last, is assumed to output
            # self.n_hidden features -- confirm against the layers in self.rgcn.
            y_user = torch.zeros(g.number_of_nodes('user'), self.n_hidden)
            y_website = torch.zeros(g.number_of_nodes('website'), self.n_hidden)
            y_others = {ntype: torch.zeros(g.number_of_nodes(ntype), self.n_hidden)
                        for ntype in other_ntypes}

            for input_nodes, output_nodes, blocks in dataloader:
                block = blocks[0].to(device)
                h_dict = {}
                if layer_idx == 0:
                    # First layer: learned ID embedding concatenated with input features.
                    user_nodes = input_nodes['user'].to(device)
                    website_nodes = input_nodes['website'].to(device)
                    u_f = user_features[input_nodes['user']].to(device)
                    w_f = web_features[input_nodes['website']].to(device)
                    h_dict['user'] = torch.cat((self.user_embedding(user_nodes), u_f), 1)
                    h_dict['website'] = torch.cat((self.website_embedding(website_nodes), w_f), 1)
                else:
                    # Later layers: the previous layer's outputs for the sampled nodes.
                    h_dict['user'] = prev_user[input_nodes['user']].to(device)
                    h_dict['website'] = prev_website[input_nodes['website']].to(device)

                for ntype in self.rgcn.embed:
                    if block.number_of_nodes(ntype) > 0:
                        # input_nodes holds global node IDs; block.nodes(ntype) would
                        # yield block-local IDs 0..k-1 and select the wrong rows.
                        src_ids = input_nodes[ntype]
                        if layer_idx == 0:
                            table = self.rgcn.embed[ntype]
                            h_dict[ntype] = table[src_ids.to(table.device)].to(device)
                        else:
                            h_dict[ntype] = prev_others[ntype][src_ids].to(device)

                h_dict = layer(block, h_dict)
                if layer_idx != last_layer:
                    h_dict = {k: F.leaky_relu(h) for k, h in h_dict.items()}

                # Scatter this batch's outputs into the current layer's buffers.
                if len(output_nodes['user']):
                    y_user[output_nodes['user']] = h_dict['user'].cpu()
                if len(output_nodes['website']):
                    y_website[output_nodes['website']] = h_dict['website'].cpu()
                for ntype in self.rgcn.embed:
                    if len(output_nodes[ntype]):
                        y_others[ntype][output_nodes[ntype]] = h_dict[ntype].cpu()

            # This layer's outputs become the next layer's inputs.
            prev_user, prev_website, prev_others = y_user, y_website, y_others
    return y_user