graphlearn_torch/python/sampler/neighbor_sampler.py:

# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import math
import threading
from typing import Dict, Optional, Union, Literal

import torch

from .. import py_graphlearn_torch as pywrap
from ..data import Graph
from ..typing import NodeType, EdgeType, NumNeighbors, reverse_edge_type
from ..utils import (
  merge_dict, merge_hetero_sampler_output, format_hetero_sampler_output,
  id2idx, count_dict
)
from .base import (
  BaseSampler, EdgeIndex,
  NodeSamplerInput, EdgeSamplerInput,
  SamplerOutput, HeteroSamplerOutput, NeighborOutput,
)
from .negative_sampler import RandomNegativeSampler


class NeighborSampler(BaseSampler):
  r""" Neighbor Sampler. """
  def __init__(self,
               graph: Union[Graph, Dict[EdgeType, Graph]],
               num_neighbors: Optional[NumNeighbors] = None,
               device: torch.device = torch.device('cuda:0'),
               with_edge: bool = False,
               with_neg: bool = False,
               with_weight: bool = False,
               strategy: str = 'random',
               edge_dir: Literal['in', 'out'] = 'out',
               seed: Optional[int] = None):
    self.graph = graph
    self.num_neighbors = num_neighbors
    self.device = device
    self.with_edge = with_edge
    self.with_neg = with_neg
    self.with_weight = with_weight
    self.strategy = strategy
    self.edge_dir = edge_dir
    self._subgraph_op = None
    self._sampler = None
    self._neg_sampler = None
    self._inducer = None
    self._sampler_lock = threading.Lock()
    self.is_sampler_initialized = False
    self.is_neg_sampler_initialized = False
    if seed is not None:
      pywrap.RandomSeedManager.getInstance().setSeed(seed)
    if isinstance(self.graph, Graph):  # homo
      self._g_cls = 'homo'
      if self.graph.mode == 'CPU':
        self.device = torch.device('cpu')
    else:  # hetero
      self._g_cls = 'hetero'
      self.edge_types = []
      self.node_types = set()
      for etype, graph in self.graph.items():
        self.edge_types.append(etype)
        self.node_types.add(etype[0])
        self.node_types.add(etype[2])
      if self.graph[self.edge_types[0]].mode == 'CPU':
        self.device = torch.device('cpu')
    self._set_num_neighbors_and_num_hops(self.num_neighbors)

  @property
  def subgraph_op(self):
    self.lazy_init_subgraph_op()
    return self._subgraph_op
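
  # The lazy_init_* helpers below create the underlying pywrap samplers and ops
  # on first use, with a double-checked pattern (flag/None check, then a
  # re-check under `self._sampler_lock`) so that loader threads sharing this
  # sampler initialize them exactly once.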

  def lazy_init_sampler(self):
    if not self.is_sampler_initialized:
      with self._sampler_lock:
        if self._sampler is None:
          if self._g_cls == 'homo':
            if self.device.type == 'cuda':
              self._sampler = pywrap.CUDARandomSampler(self.graph.graph_handler)
            elif not self.with_weight:
              self._sampler = pywrap.CPURandomSampler(self.graph.graph_handler)
            else:
              self._sampler = pywrap.CPUWeightedSampler(self.graph.graph_handler)
            self.is_sampler_initialized = True
          else:  # hetero
            self._sampler = {}
            for etype, g in self.graph.items():
              if self.device != torch.device('cpu'):
                self._sampler[etype] = pywrap.CUDARandomSampler(g.graph_handler)
              elif not self.with_weight:
                self._sampler[etype] = pywrap.CPURandomSampler(g.graph_handler)
              else:
                self._sampler[etype] = pywrap.CPUWeightedSampler(g.graph_handler)
            self.is_sampler_initialized = True

  def lazy_init_neg_sampler(self):
    if not self.is_neg_sampler_initialized and self.with_neg:
      with self._sampler_lock:
        if self._neg_sampler is None:
          if self._g_cls == 'homo':
            self._neg_sampler = RandomNegativeSampler(
              graph=self.graph,
              mode=self.device.type.upper(),
              edge_dir=self.edge_dir
            )
            self.is_neg_sampler_initialized = True
          else:  # hetero
            self._neg_sampler = {}
            for etype, g in self.graph.items():
              self._neg_sampler[etype] = RandomNegativeSampler(
                graph=g,
                mode=self.device.type.upper(),
                edge_dir=self.edge_dir
              )
            self.is_neg_sampler_initialized = True

  def lazy_init_subgraph_op(self):
    if self._subgraph_op is None:
      with self._sampler_lock:
        if self._subgraph_op is None:
          if self.device.type == 'cuda':
            self._subgraph_op = pywrap.CUDASubGraphOp(self.graph.graph_handler)
          else:
            self._subgraph_op = pywrap.CPUSubGraphOp(self.graph.graph_handler)

  def sample_one_hop(
    self,
    input_seeds: torch.Tensor,
    req_num: int,
    etype: Optional[EdgeType] = None
  ) -> NeighborOutput:
    self.lazy_init_sampler()
    sampler = self._sampler[etype] if etype is not None else self._sampler
    input_seeds = input_seeds.to(self.device)
    edge_ids = None
    if not self.with_edge:
      nbrs, nbrs_num = sampler.sample(input_seeds, req_num)
    else:
      nbrs, nbrs_num, edge_ids = sampler.sample_with_edge(input_seeds, req_num)
    if nbrs.numel() == 0:
      nbrs = torch.tensor([], dtype=torch.int64, device=self.device)
      nbrs_num = torch.zeros_like(input_seeds, dtype=torch.int64, device=self.device)
      edge_ids = torch.tensor([], device=self.device, dtype=torch.int64) \
        if self.with_edge else None
    return NeighborOutput(nbrs, nbrs_num, edge_ids)

  def sample_from_nodes(
    self,
    inputs: NodeSamplerInput,
    **kwargs
  ) -> Union[HeteroSamplerOutput, SamplerOutput]:
    inputs = NodeSamplerInput.cast(inputs)
    input_seeds = inputs.node.to(self.device)
    input_type = inputs.input_type
    if self._g_cls == 'hetero':
      assert input_type is not None
      output = self._hetero_sample_from_nodes({input_type: input_seeds})
    else:
      output = self._sample_from_nodes(input_seeds)
    return output

  def _sample_from_nodes(
    self,
    input_seeds: torch.Tensor
  ) -> SamplerOutput:
    r""" Sample on homogeneous graphs and induce a COO format subgraph.

    Note that messages in PyG are passed from src to dst. In the 'out'
    direction, we sample src's out-neighbors and induce [src_index, dst_index]
    subgraphs. The direction of sampling is opposite to the direction of
    message passing. To be consistent with the semantics of PyG, the final
    edge index is transposed to [dst_index, src_index]. In the 'in' direction,
    we don't need to reverse it.
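
    For example (illustrative): with ``edge_dir='out'``, if seed ``u``'s
    out-neighbor ``v`` is sampled along the edge ``u -> v``, the returned COO
    pair is (row=v's local index, col=u's local index), so PyG aggregates
    ``v``'s features into ``u``.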
""" out_nodes, out_rows, out_cols, out_edges = [], [], [], [] num_sampled_nodes, num_sampled_edges = [], [] inducer = self.get_inducer(input_seeds.numel()) srcs = inducer.init_node(input_seeds) batch = srcs num_sampled_nodes.append(input_seeds.numel()) out_nodes.append(srcs) for req_num in self.num_neighbors: out_nbrs = self.sample_one_hop(srcs, req_num) if out_nbrs.nbr.numel() == 0: break nodes, rows, cols = inducer.induce_next( srcs, out_nbrs.nbr, out_nbrs.nbr_num) out_nodes.append(nodes) out_rows.append(rows) out_cols.append(cols) if out_nbrs.edge is not None: out_edges.append(out_nbrs.edge) num_sampled_nodes.append(nodes.size(0)) num_sampled_edges.append(cols.size(0)) srcs = nodes return SamplerOutput( node=torch.cat(out_nodes), row=torch.cat(out_cols) if len(out_cols) > 0 else torch.tensor(out_cols), col=torch.cat(out_rows) if len(out_rows) > 0 else torch.tensor(out_rows), edge=(torch.cat(out_edges) if out_edges else None), batch=batch, num_sampled_nodes=num_sampled_nodes, num_sampled_edges=num_sampled_edges, device=self.device ) def _hetero_sample_from_nodes( self, input_seeds_dict: Dict[NodeType, torch.Tensor], ) -> HeteroSamplerOutput: r""" Sample on heterogenous graphs and induce COO format subgraph dict. Note that messages in PyG are passed from src to dst. In 'out' direction, we sample src's out neighbors and induce [src_index, dst_index] subgraphs. The direction of sampling is opposite to the direction of message passing. To be consistent with the semantics of PyG, the final edge index is transpose to [dst_index, src_index] and edge_type is reversed as well. For example, given the edge_type (u, u2i, i), we sample by meta-path u->i, but return edge_index_dict {(i, rev_u2i, u) : [i, u]}. In 'in' direction, we don't need to reverse it. """ # sample neighbors hop by hop. 
    max_input_batch_size = max([t.numel() for t in input_seeds_dict.values()])
    inducer = self.get_inducer(max_input_batch_size)
    src_dict = inducer.init_node(input_seeds_dict)
    batch = src_dict
    out_nodes, out_rows, out_cols, out_edges = {}, {}, {}, {}
    num_sampled_nodes, num_sampled_edges = {}, {}
    merge_dict(src_dict, out_nodes)
    count_dict(src_dict, num_sampled_nodes, 1)
    for i in range(self.num_hops):
      nbr_dict, edge_dict = {}, {}
      for etype in self.edge_types:
        req_num = self.num_neighbors[etype][i]
        # 'in' sampling takes seeds of the edge's dst type (etype[-1]),
        # 'out' sampling takes seeds of its src type (etype[0]).
        if self.edge_dir == 'in':
          src = src_dict.get(etype[-1], None)
          if src is not None and src.numel() > 0:
            output = self.sample_one_hop(src, req_num, etype)
            if output.nbr.numel() == 0:
              continue
            nbr_dict[reverse_edge_type(etype)] = [src, output.nbr, output.nbr_num]
            if output.edge is not None:
              edge_dict[reverse_edge_type(etype)] = output.edge
        elif self.edge_dir == 'out':
          src = src_dict.get(etype[0], None)
          if src is not None and src.numel() > 0:
            output = self.sample_one_hop(src, req_num, etype)
            if output.nbr.numel() == 0:
              continue
            nbr_dict[etype] = [src, output.nbr, output.nbr_num]
            if output.edge is not None:
              edge_dict[etype] = output.edge
      if len(nbr_dict) == 0:
        continue
      nodes_dict, rows_dict, cols_dict = inducer.induce_next(nbr_dict)
      merge_dict(nodes_dict, out_nodes)
      merge_dict(rows_dict, out_rows)
      merge_dict(cols_dict, out_cols)
      merge_dict(edge_dict, out_edges)
      count_dict(nodes_dict, num_sampled_nodes, i + 2)
      count_dict(cols_dict, num_sampled_edges, i + 1)
      src_dict = nodes_dict

    for etype, rows in out_rows.items():
      out_rows[etype] = torch.cat(rows)
      out_cols[etype] = torch.cat(out_cols[etype])
      if self.with_edge:
        out_edges[etype] = torch.cat(out_edges[etype])
    res_rows, res_cols, res_edges = {}, {}, {}
    for etype, rows in out_rows.items():
      rev_etype = reverse_edge_type(etype)
      res_rows[rev_etype] = out_cols[etype]
      res_cols[rev_etype] = rows
      if self.with_edge:
        res_edges[rev_etype] = out_edges[etype]

    return HeteroSamplerOutput(
      node={k: torch.cat(v) for k, v in out_nodes.items()},
      row=res_rows,
      col=res_cols,
      edge=(res_edges if len(res_edges) else None),
      batch=batch,
      num_sampled_nodes={k: torch.tensor(v, device=self.device)
                         for k, v in num_sampled_nodes.items()},
      num_sampled_edges={reverse_edge_type(k): torch.tensor(v, device=self.device)
                         for k, v in num_sampled_edges.items()},
      edge_types=self.edge_types,
      device=self.device
    )

  def sample_from_edges(
    self,
    inputs: EdgeSamplerInput,
    **kwargs,
  ) -> Union[HeteroSamplerOutput, SamplerOutput]:
    r"""Performs sampling from an edge sampler input, leveraging a sampling
    function of the same signature as `node_sample`.

    Note that in out-edge sampling, we reverse the direction of src and dst
    for the output so that features of the sampled nodes during training can
    be aggregated from k-hop to (k-1)-hop nodes.
    """
    src = inputs.row.to(self.device)
    dst = inputs.col.to(self.device)
    edge_label = None if inputs.label is None else inputs.label.to(self.device)
    input_type = inputs.input_type
    neg_sampling = inputs.neg_sampling
    num_pos = src.numel()
    num_neg = 0

    # Negative Sampling
    self.lazy_init_neg_sampler()
    if neg_sampling is not None:
      # When we are doing negative sampling, we append negative information
      # of nodes/edges to `src`, `dst`.
      # Later on, we can easily reconstruct what belongs to positive and
      # negative examples by slicing via `num_pos`.
      num_neg = math.ceil(num_pos * neg_sampling.amount)
      if neg_sampling.is_binary():
        # In the "binary" case, we randomly sample negative pairs of nodes.
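        # Sampled negatives are appended after the positives; positives keep
        # their given labels (defaulting to 1) and negatives are labeled 0,
        # so `num_pos` is enough to slice them apart later.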
        if input_type is not None:
          neg_pair = self._neg_sampler[input_type].sample(num_neg)
        else:
          neg_pair = self._neg_sampler.sample(num_neg)
        src_neg, dst_neg = neg_pair[0], neg_pair[1]
        src = torch.cat([src, src_neg], dim=0)
        dst = torch.cat([dst, dst_neg], dim=0)
        if edge_label is None:
          edge_label = torch.ones(num_pos, device=self.device)
        size = (src_neg.size()[0], ) + edge_label.size()[1:]
        edge_neg_label = edge_label.new_zeros(size)
        edge_label = torch.cat([edge_label, edge_neg_label])
      elif neg_sampling.is_triplet():
        # TODO: make triplet negative sampling strict.
        # In the "triplet" case, we randomly sample negative destinations
        # in a "non-strict" manner.
        assert num_neg % num_pos == 0
        if input_type is not None:
          neg_pair = self._neg_sampler[input_type].sample(num_neg, padding=True)
        else:
          neg_pair = self._neg_sampler.sample(num_neg, padding=True)
        dst_neg = neg_pair[1]
        dst = torch.cat([dst, dst_neg], dim=0)
        assert edge_label is None

    # Neighbor Sampling
    if input_type is not None:  # hetero
      if input_type[0] != input_type[-1]:
        # Two distinct node types:
        src_seed, dst_seed = src, dst
        src, inverse_src = src.unique(return_inverse=True)
        dst, inverse_dst = dst.unique(return_inverse=True)
        seed_dict = {input_type[0]: src, input_type[-1]: dst}
      else:
        # Only a single node type: Merge both source and destination.
        seed = torch.cat([src, dst], dim=0)
        seed, inverse_seed = seed.unique(return_inverse=True)
        seed_dict = {input_type[0]: seed}
      temp_out = []
      for it, node in seed_dict.items():
        seeds = NodeSamplerInput(node=node, input_type=it)
        temp_out.append(self.sample_from_nodes(seeds))
      if len(temp_out) == 2:
        out = merge_hetero_sampler_output(temp_out[0], temp_out[1],
                                          device=self.device,
                                          edge_dir=self.edge_dir)
      else:
        out = format_hetero_sampler_output(temp_out[0], edge_dir=self.edge_dir)

      # edge_label
      if neg_sampling is None or neg_sampling.is_binary():
        if input_type[0] != input_type[-1]:
          inverse_src = id2idx(out.node[input_type[0]])[src_seed]
          inverse_dst = id2idx(out.node[input_type[-1]])[dst_seed]
          edge_label_index = torch.stack([
            inverse_src,
            inverse_dst,
          ], dim=0)
        else:
          edge_label_index = inverse_seed.view(2, -1)
        out.metadata = {'edge_label_index': edge_label_index,
                        'edge_label': edge_label}
        out.input_type = input_type
      elif neg_sampling.is_triplet():
        if input_type[0] != input_type[-1]:
          inverse_src = id2idx(out.node[input_type[0]])[src_seed]
          inverse_dst = id2idx(out.node[input_type[-1]])[dst_seed]
          src_index = inverse_src
          dst_pos_index = inverse_dst[:num_pos]
          dst_neg_index = inverse_dst[num_pos:]
        else:
          src_index = inverse_seed[:num_pos]
          dst_pos_index = inverse_seed[num_pos:2 * num_pos]
          dst_neg_index = inverse_seed[2 * num_pos:]
        dst_neg_index = dst_neg_index.view(num_pos, -1).squeeze(-1)
        out.metadata = {'src_index': src_index,
                        'dst_pos_index': dst_pos_index,
                        'dst_neg_index': dst_neg_index}
        out.input_type = input_type
    else:  # homo
      seed = torch.cat([src, dst], dim=0)
      seed, inverse_seed = seed.unique(return_inverse=True)
      out = self.sample_from_nodes(seed)

      # edge_label
      if neg_sampling is None or neg_sampling.is_binary():
        edge_label_index = inverse_seed.view(2, -1)
        out.metadata = {'edge_label_index': edge_label_index,
                        'edge_label': edge_label}
      elif neg_sampling.is_triplet():
        src_index = inverse_seed[:num_pos]
        dst_pos_index = inverse_seed[num_pos:2 * num_pos]
        dst_neg_index = inverse_seed[2 * num_pos:]
        dst_neg_index = dst_neg_index.view(num_pos, -1).squeeze(-1)
        out.metadata = {'src_index': src_index,
                        'dst_pos_index': dst_pos_index,
                        'dst_neg_index': dst_neg_index}

    return out
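
  # Illustrative usage sketch (comments only, not part of the API surface):
  # given a `NeighborSampler` instance `sampler` on a homogeneous graph, and
  # assuming `EdgeSamplerInput` can be built directly from row/col id tensors,
  # link-prediction sampling would look roughly like:
  #
  #   inputs = EdgeSamplerInput(row=pos_edge_index[0], col=pos_edge_index[1])
  #   out = sampler.sample_from_edges(inputs)
  #   edge_label_index = out.metadata['edge_label_index']
  #   edge_label = out.metadata['edge_label']
  #
  # The exact construction of `EdgeSamplerInput` (and any `neg_sampling` spec)
  # is up to the caller, e.g. a link loader built on top of this sampler.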

  def sample_pyg_v1(self, ids: torch.Tensor):
    r""" Sample multi-hop neighbors and organize the results into PyG's
    `EdgeIndex`.

    Args:
      ids: input ids, 1D tensor.

    Returns:
      The sampled results in the same form as PyG's `NeighborSampler` (PyG v1).
    """
    ids = ids.to(self.device)
    adjs = []
    srcs = ids
    out_ids = ids
    batch_size = 0
    inducer = self.get_inducer(srcs.numel())
    for i, req_num in enumerate(self.num_neighbors):
      srcs = inducer.init_node(srcs)
      batch_size = srcs.numel() if i == 0 else batch_size
      out_nbrs = self.sample_one_hop(srcs, req_num)
      nodes, rows, cols = \
        inducer.induce_next(srcs, out_nbrs.nbr, out_nbrs.nbr_num)
      edge_index = torch.stack([cols, rows])  # we use csr instead of csc in PyG.
      out_ids = torch.cat([srcs, nodes])
      adj_size = torch.LongTensor([out_ids.size(0), srcs.size(0)])
      adjs.append(EdgeIndex(edge_index, out_nbrs.edge, adj_size))
      srcs = out_ids
    return batch_size, out_ids, adjs[::-1]

  def subgraph(
    self,
    inputs: NodeSamplerInput,
  ) -> SamplerOutput:
    self.lazy_init_subgraph_op()
    inputs = NodeSamplerInput.cast(inputs)
    input_seeds = inputs.node.to(self.device)
    if self.num_neighbors is not None:
      nodes = [input_seeds]
      for num in self.num_neighbors:
        nbr = self.sample_one_hop(nodes[-1], num).nbr
        nodes.append(torch.unique(nbr))
      nodes, mapping = torch.cat(nodes).unique(return_inverse=True)
    else:
      nodes, mapping = torch.unique(input_seeds, return_inverse=True)
    subgraph = self._subgraph_op.node_subgraph(nodes, self.with_edge)
    return SamplerOutput(
      node=subgraph.nodes,
      # The edge index should be reversed.
      row=subgraph.cols,
      col=subgraph.rows,
      edge=subgraph.eids if self.with_edge else None,
      device=self.device,
      metadata=mapping[:input_seeds.numel()])

  def sample_prob(
    self,
    inputs: NodeSamplerInput,
    node_cnt: Union[int, Dict[NodeType, int]]
  ) -> Union[torch.Tensor, Dict[NodeType, torch.Tensor]]:
    r""" Get the probability of each node being sampled.
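
    Probabilities are propagated hop by hop through ``cal_nbr_prob``: seed
    nodes start with probability 1 and all other nodes with a small prior
    (0.01 for homogeneous graphs, 0.005 per subgraph in the heterogeneous
    case); see `_sample_prob` and `_hetero_sample_prob` below.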
""" self.lazy_init_sampler() inputs = NodeSamplerInput.cast(inputs) input_seeds = inputs.node.to(self.device) input_type = inputs.input_type if self._g_cls == 'hetero': assert input_type is not None output = self._hetero_sample_prob({input_type : input_seeds}, node_cnt) else: output = self._sample_prob(input_seeds, node_cnt) return output def _sample_prob( self, input_seeds: torch.Tensor, node_cnt: int ) -> torch.Tensor: last_prob = \ torch.ones(node_cnt, device=self.device, dtype=torch.float32) * 0.01 last_prob[input_seeds] = 1 for req in self.num_neighbors: cur_prob = torch.zeros(node_cnt, device=self.device, dtype=torch.float32) self._sampler.cal_nbr_prob( req, last_prob, last_prob, self.graph.graph_handler, cur_prob ) last_prob = cur_prob return last_prob def _hetero_sample_prob( self, input_seeds_dict: Dict[NodeType, torch.Tensor], node_dict: Dict[NodeType, int] ) -> Dict[NodeType, torch.Tensor]: probs = {} for ntype in node_dict.keys(): probs[ntype] = [] # calculate probs for each subgraph for i in range(self.num_hops): for etype in self.edge_types: req = self.num_neighbors[etype][i] # homogenous subgraph case if etype[0] == etype[2]: if len(probs[etype[0]]) == 0: last_prob = torch.ones(node_dict[etype[0]].size(0), device=self.device, dtype=torch.float32) * 0.005 last_prob[input_seeds_dict[etype[0]]] = 1 else: last_prob = self.aggregate_prob(probs[etype[0]], node_dict[etype[0]].size(0), device=self.device) cur_prob = torch.zeros(node_dict[etype[0]].size(0), device=self.device, dtype=torch.float32) self._sampler[etype].cal_nbr_prob( req, last_prob, last_prob, self._graph_dict[etype].graph_handler, cur_prob ) last_prob = cur_prob probs[etype[0]].append(last_prob) # hetero bipartite graph case else: if len(probs[etype[0]]) == 0: last_prob = torch.ones(node_dict[etype[0]].size(0), device=self.device, dtype=torch.float32) * 0.005 last_prob[input_seeds_dict[etype[0]]] = 1 else: last_prob = self.aggregate_prob(probs[etype[0]], node_dict[etype[0]].size(0), device=self.device) etypes = [nbr_etype for nbr_etype in self.edge_types if nbr_etype[0] == etype[2]] temp_probs = [] # prepare nbr_prob if len(probs[etype[2]]) == 0: nbr_prob = torch.ones(node_dict[etype[2]].size(0), device=self.device, dtype=torch.float32) * 0.005 if etype[2] in input_seeds_dict: nbr_prob[input_seeds_dict[etype[2]]] = 1 else: nbr_prob = self.aggregate_prob(probs[etype[2]], node_dict[etype[2]].size(0), device=self.device) for nbr_etype in etypes: cur_prob = torch.zeros(node_dict[etype[0]].size(0), device=self.device, dtype=torch.float32) self._sampler[etype].cal_nbr_prob( req, last_prob, nbr_prob, self._graph_dict[nbr_etype].graph_handler, cur_prob ) last_prob = cur_prob temp_probs.append(last_prob) # aggregate prob for the bipartite graph # with #{subgraphs where the neighbours are} sub_temp_prob = self.aggregate_prob(temp_probs, node_dict[etype[0]].size(0), device=self.device) probs[etype[0]].append(sub_temp_prob) # aggregate probs from each subgraph # with #{subgraphs} for ntype, prob in probs.items(): res = self.aggregate_prob( prob, node_dict[ntype].size(0), device=self.device) if i == self.num_hops - 1: probs[ntype] = res else: probs[ntype] = [res] return probs def get_inducer(self, input_batch_size: int): if self._inducer is None: self._inducer = self.create_inducer(input_batch_size) return self._inducer def create_inducer(self, input_batch_size: int): max_num_nodes = self._max_sampled_nodes(input_batch_size) if self.device.type == 'cuda': if self._g_cls == 'homo': inducer = pywrap.CUDAInducer(max_num_nodes) else: 

  def create_inducer(self, input_batch_size: int):
    max_num_nodes = self._max_sampled_nodes(input_batch_size)
    if self.device.type == 'cuda':
      if self._g_cls == 'homo':
        inducer = pywrap.CUDAInducer(max_num_nodes)
      else:
        inducer = pywrap.CUDAHeteroInducer(max_num_nodes)
    else:
      if self._g_cls == 'homo':
        inducer = pywrap.CPUInducer(max_num_nodes)
      else:
        inducer = pywrap.CPUHeteroInducer(max_num_nodes)
    return inducer

  def _set_num_neighbors_and_num_hops(self, num_neighbors):
    if isinstance(num_neighbors, (list, tuple)) and self._g_cls == 'hetero':
      num_neighbors = {key: num_neighbors for key in self.edge_types}
    self.num_neighbors = num_neighbors
    if isinstance(num_neighbors, dict):
      # Add at least one element to the list to ensure `max` is well-defined
      self.num_hops = max([0] + [len(v) for v in num_neighbors.values()])
      for key, value in self.num_neighbors.items():
        if len(value) != self.num_hops:
          raise ValueError(f"Expected the edge type {key} to have "
                           f"{self.num_hops} entries (got {len(value)})")
    else:
      # Homogeneous graphs keep a plain per-hop list (or None).
      self.num_hops = 0 if num_neighbors is None else len(num_neighbors)

  def _max_sampled_nodes(
    self,
    input_batch_size: int,
  ) -> Union[int, Dict[str, int]]:
    if self._g_cls == 'homo':
      res = [input_batch_size]
      for num in self.num_neighbors:
        res.append(res[-1] * num)
      return sum(res)
    res = {k: [] for k in self.node_types}
    for etype, num_list in self.num_neighbors.items():
      tmp_res = [input_batch_size]
      for num in num_list:
        tmp_res.append(tmp_res[-1] * num)
      res[etype[0]].extend(tmp_res)
      res[etype[2]].extend(tmp_res)
    return {k: sum(v) for k, v in res.items()}

  def _aggregate_prob(self, probs, node_num, device):
    """ Aggregate probs from each subgraph:
      p = 1 - ((1-p_1)(1-p_2)...(1-p_k))**(1/k), where k := #{subgraphs}
    """
    res = torch.ones(node_num, device=device, dtype=torch.float32)
    for temp_prob in probs:
      # To avoid the case that p_i = 1 forces p = 1, so that the whole
      # importance won't be decided by a single term.
      res *= (1 + .002 - temp_prob)
    res = 1 - res ** (1 / len(probs))
    return res.clamp(min=0.0)
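
# Illustrative usage sketch (comments only; how the `Graph` itself is built is
# outside the scope of this file and depends on the topology construction):
#
#   g = ...  # a graphlearn_torch.data.Graph (homogeneous), built elsewhere
#   sampler = NeighborSampler(g, num_neighbors=[10, 5], with_edge=True)
#   out = sampler.sample_from_nodes(
#     NodeSamplerInput(node=torch.tensor([0, 1, 2]), input_type=None))
#   # `out.node` holds the global ids of all sampled nodes, `out.row`/`out.col`
#   # form the local COO edge index, and `out.edge` the original edge ids.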