def inference()

in source/sagemaker/sagemaker_graph_entity_resolution/dgl_entity_resolution/model.py [0:0]


    def inference(self, g, user_features, web_features, batch_size, n_neighbors, device, num_workers=0):
        """Compute representations for every node of ``g``, one RGCN layer at a time.

        Each layer is applied to all nodes in mini-batches, with one sampled
        neighbor block per batch. The outputs of layer ``l`` (kept for all nodes
        on the CPU) are the inputs of layer ``l + 1``. The first layer's inputs
        are:

        * ``'user'`` / ``'website'``: learned ID embeddings concatenated with
          the given input features.
        * Every ntype in ``self.rgcn.embed``: rows of that embedding table.

        Args:
            g: DGL heterograph containing at least ``'user'`` and ``'website'``
                node types.
            user_features: input features of ``'user'`` nodes, indexed by the
                global node IDs produced by the data loader.
            web_features: input features of ``'website'`` nodes, indexed the
                same way.
            batch_size: number of seed nodes per mini-batch.
            n_neighbors: neighbor fan-out used by the one-hop sampler.
            device: device on which the per-batch layer computation runs.
            num_workers: number of data-loader worker processes.

        Returns:
            CPU tensor of shape ``(g.number_of_nodes('user'), self.n_hidden)``
            holding the last layer's ``'user'`` representations.
        """
        # Each pass only needs one hop of neighbors, because the previous
        # layer's outputs are already materialised for every node. The same
        # sampler/loader can therefore be reused for all layers.
        sampler = dgl.dataloading.MultiLayerNeighborSampler([n_neighbors])
        dataloader = dgl.dataloading.NodeDataLoader(
            g,
            {ntype: torch.arange(g.number_of_nodes(ntype)) for ntype in g.ntypes},
            sampler,
            batch_size=batch_size,
            # Results are scattered back by node ID, so batch order is
            # irrelevant; iterate deterministically.
            shuffle=False,
            drop_last=False,
            num_workers=num_workers)

        other_ntypes = [ntype for ntype in g.ntypes if ntype != 'user' and ntype != 'website']
        n_layers = len(self.rgcn.layers)

        # Previous layer's outputs, read as inputs by layers l > 0.
        x_user = x_website = None
        x_others = {}

        # Pure inference: don't build an autograd graph that would otherwise
        # chain through the output buffers across every batch.
        with torch.no_grad():
            for l, layer in enumerate(self.rgcn.layers):
                # Fresh output buffers for this layer. They must not alias the
                # previous layer's outputs (x_*), which are read below.
                # NOTE(review): assumes every layer, including the last,
                # outputs self.n_hidden features — confirm.
                y_user = torch.zeros(g.number_of_nodes('user'), self.n_hidden)
                y_website = torch.zeros(g.number_of_nodes('website'), self.n_hidden)
                y_others = {ntype: torch.zeros(g.number_of_nodes(ntype), self.n_hidden)
                            for ntype in other_ntypes}

                for input_nodes, output_nodes, blocks in dataloader:
                    block = blocks[0].to(device)

                    if l == 0:
                        # Initial features: learned ID embeddings concatenated
                        # with the provided input features.
                        user_nodes = input_nodes['user'].to(device)
                        website_nodes = input_nodes['website'].to(device)
                        u_f = user_features[input_nodes['user']].to(device)
                        w_f = web_features[input_nodes['website']].to(device)
                        u = torch.cat((self.user_embedding(user_nodes), u_f), 1)
                        w = torch.cat((self.website_embedding(website_nodes), w_f), 1)
                    else:
                        # Intermediate representations from the previous layer.
                        u = x_user[input_nodes['user']].to(device)
                        w = x_website[input_nodes['website']].to(device)

                    h_dict = {'user': u, 'website': w}

                    for ntype in self.rgcn.embed:
                        if block.number_of_nodes(ntype) > 0:
                            src_ids = input_nodes[ntype]
                            if l == 0:
                                # Index by *global* node ID (as for user/website
                                # above). Block-local IDs (block.nodes(ntype))
                                # would select rows of unrelated nodes.
                                h_dict[ntype] = self.rgcn.embed[ntype][src_ids.long()].to(device)
                            else:
                                h_dict[ntype] = x_others[ntype][src_ids].to(device)

                    h_dict = layer(block, h_dict)
                    if l != n_layers - 1:
                        h_dict = {k: F.leaky_relu(h) for k, h in h_dict.items()}

                    if len(output_nodes['user']):
                        y_user[output_nodes['user']] = h_dict['user'].cpu()
                    if len(output_nodes['website']):
                        y_website[output_nodes['website']] = h_dict['website'].cpu()
                    for ntype in self.rgcn.embed:
                        if len(output_nodes[ntype]):
                            y_others[ntype][output_nodes[ntype]] = h_dict[ntype].cpu()

                # This layer's outputs become the next layer's inputs.
                x_user, x_website, x_others = y_user, y_website, y_others

        return y_user