def __init__()

in graphlearn_torch/python/sampler/neighbor_sampler.py [0:0]


  def __init__(self,
               graph: Union[Graph, Dict[EdgeType, Graph]],
               num_neighbors: Optional[NumNeighbors] = None,
               device: torch.device=torch.device('cuda:0'),
               with_edge: bool=False,
               with_neg: bool=False,
               with_weight: bool=False,
               strategy: str = 'random',
               edge_dir: Literal['in', 'out'] = 'out',
               seed: Optional[int] = None):
    r"""Initialize the neighbor sampler.

    Args:
      graph: Either a single ``Graph`` (homogeneous case) or a dict mapping
        each edge type to its ``Graph`` (heterogeneous case).
      num_neighbors: Neighbor fan-out configuration. In the heterogeneous
        case it is passed to ``_set_num_neighbors_and_num_hops``; in the
        homogeneous case it is only stored as given.
      device: Device to sample on. It is overridden to CPU when the graph
        has ``mode == 'CPU'``. In the heterogeneous case, only the graph of
        the first edge type is checked.
      with_edge: Stored as ``self.with_edge``. Presumably requests edge ids
        in the output; it is consumed elsewhere.
      with_neg: Stored as ``self.with_neg``. Presumably enables negative
        sampling; it is consumed elsewhere.
      with_weight: Stored as ``self.with_weight``. Presumably enables
        weighted sampling; it is consumed elsewhere.
      strategy: Sampling strategy name, stored as ``self.strategy``.
      edge_dir: Edge direction to follow, ``'in'`` or ``'out'``.
      seed: If not ``None``, passed to
        ``pywrap.RandomSeedManager.getInstance().setSeed``.
        NOTE(review): that manager looks process-wide rather than
        per-sampler, so this may affect other samplers. Confirm.

    Raises:
      ValueError: If ``graph`` is a dict with no edge types.
    """
    self.graph = graph
    self.num_neighbors = num_neighbors
    self.device = device
    self.with_edge = with_edge
    self.with_neg = with_neg
    self.with_weight = with_weight
    self.strategy = strategy
    self.edge_dir = edge_dir
    # The sampling operators are not built here; they start unset.
    # Presumably they are created lazily by other methods, with
    # `_sampler_lock` guarding that creation. Verify against callers.
    self._subgraph_op = None
    self._sampler = None
    self._neg_sampler = None
    self._inducer = None
    self._sampler_lock = threading.Lock()
    self.is_sampler_initialized = False
    self.is_neg_sampler_initialized = False

    if seed is not None:
      pywrap.RandomSeedManager.getInstance().setSeed(seed)

    if isinstance(self.graph, Graph):  # homo
      self._g_cls = 'homo'
      if self.graph.mode == 'CPU':
        self.device = torch.device('cpu')
    else:  # hetero
      self._g_cls = 'hetero'
      # Only the edge-type keys are needed. Iterating keys avoids rebinding
      # the `graph` argument, which the old `.items()` loop did.
      self.edge_types = list(self.graph.keys())
      if not self.edge_types:
        # Fail clearly instead of with an IndexError on `edge_types[0]` below.
        raise ValueError(
          "heterogeneous 'graph' must contain at least one edge type")
      self.node_types = set()
      for etype in self.edge_types:
        # Element 0 of an edge type is the source node type and element 2
        # is the destination node type.
        self.node_types.add(etype[0])
        self.node_types.add(etype[2])
      if self.graph[self.edge_types[0]].mode == 'CPU':
        self.device = torch.device('cpu')
      self._set_num_neighbors_and_num_hops(self.num_neighbors)