in graphlearn_torch/python/sampler/neighbor_sampler.py [0:0]
def __init__(self,
             graph: Union[Graph, Dict[EdgeType, Graph]],
             num_neighbors: Optional[NumNeighbors] = None,
             device: torch.device=torch.device('cuda:0'),
             with_edge: bool=False,
             with_neg: bool=False,
             with_weight: bool=False,
             strategy: str = 'random',
             edge_dir: Literal['in', 'out'] = 'out',
             seed: Optional[int] = None):
  r""" Set up sampler configuration; the underlying sampling ops start unset.

  Args:
    graph: Either a single ``Graph`` (homogeneous case) or a dict mapping
      edge types to ``Graph`` objects (heterogeneous case). For dict keys,
      entries ``[0]`` and ``[2]`` are collected as node types.
    num_neighbors: Neighbor fan-out configuration, forwarded to
      ``self._set_num_neighbors_and_num_hops``.
    device: Requested device. Overridden with ``torch.device('cpu')`` when
      the (first) graph's ``mode`` is ``'CPU'``.
    with_edge: Stored as ``self.with_edge``.
    with_neg: Stored as ``self.with_neg``.
    with_weight: Stored as ``self.with_weight``.
    strategy: Sampling strategy name, stored as ``self.strategy``.
    edge_dir: ``'in'`` or ``'out'``, stored as ``self.edge_dir``.
    seed: If given, passed to
      ``pywrap.RandomSeedManager.getInstance().setSeed`` before anything
      else is configured.

  Raises:
    ValueError: If ``graph`` is a heterogeneous dict with no edge types.
  """
  self.graph = graph
  self.num_neighbors = num_neighbors
  self.device = device
  self.with_edge = with_edge
  self.with_neg = with_neg
  self.with_weight = with_weight
  self.strategy = strategy
  self.edge_dir = edge_dir
  # Sampling ops are left unset here; presumably built lazily elsewhere
  # (NOTE(review): confirm against the rest of the class).
  self._subgraph_op = None
  self._sampler = None
  self._neg_sampler = None
  self._inducer = None
  self._sampler_lock = threading.Lock()
  self.is_sampler_initialized = False
  self.is_neg_sampler_initialized = False

  if seed is not None:
    pywrap.RandomSeedManager.getInstance().setSeed(seed)

  if isinstance(self.graph, Graph):  # homo
    self._g_cls = 'homo'
    reference_graph = self.graph
  else:  # hetero
    if not self.graph:
      # Without this check the lookup of the first edge type below fails
      # with an uninformative IndexError.
      raise ValueError(
        f"'graph' must be a Graph or a non-empty dict of edge type -> Graph, "
        f"got {self.graph!r}")
    self._g_cls = 'hetero'
    # Iterate keys only: the per-edge-type Graph values are not needed here.
    self.edge_types = list(self.graph.keys())
    self.node_types = set()
    for etype in self.edge_types:
      self.node_types.add(etype[0])
      self.node_types.add(etype[2])
    # NOTE(review): only the first edge type's mode is consulted; this
    # assumes all per-edge-type graphs share the same mode.
    reference_graph = self.graph[self.edge_types[0]]

  # A CPU-mode graph forces sampling onto the CPU regardless of `device`.
  if reference_graph.mode == 'CPU':
    self.device = torch.device('cpu')

  self._set_num_neighbors_and_num_hops(self.num_neighbors)