in graphlearn_torch/python/distributed/dist_neighbor_sampler.py [0:0]
def __init__(self,
             data: DistDataset,
             num_neighbors: Optional[NumNeighbors] = None,
             with_edge: bool = False,
             with_neg: bool = False,
             with_weight: bool = False,
             edge_dir: Literal['in', 'out'] = 'out',
             collect_features: bool = False,
             channel: Optional[ChannelBase] = None,
             use_all2all: bool = False,
             concurrency: int = 1,
             device: Optional[torch.device] = None,
             seed: Optional[int] = None):
  """Set up distributed graph/feature views, the local sampler and its RPC
  callees for the local data partition.

  Args:
    data: The local ``DistDataset`` partition. Any other type raises
      ``ValueError``.
    num_neighbors: Fan-out configuration forwarded to ``NeighborSampler``.
      For heterogeneous graphs it is replaced by the sampler's own
      ``num_neighbors`` after the sampler is built.
    with_edge: Forwarded to ``NeighborSampler``. Edge features are wrapped
      in a ``DistFeature`` only when both this and ``collect_features``
      are set.
    with_neg: Forwarded to ``NeighborSampler``.
    with_weight: Forwarded to ``NeighborSampler``.
    edge_dir: Edge direction (``'in'`` or ``'out'``) forwarded to
      ``NeighborSampler``.
    collect_features: If True, wrap the dataset's node features (and edge
      features when ``with_edge`` is set) in ``DistFeature`` objects built
      with ``local_only=False`` and the RPC router.
    channel: Stored on the instance. It is not used in this method.
    use_all2all: Stored on the instance. It is not used in this method.
    concurrency: Capacity of ``inducer_pool``. Also passed to the base-class
      constructor.
    device: Requested device, resolved via ``get_available_device``.
    seed: Optional seed forwarded to ``NeighborSampler``.

  Raises:
    ValueError: If ``data`` is not a ``DistDataset``, or if heterogeneous
      node labels are neither ``None`` nor a ``dict``.
  """
  self.data = data
  self.use_all2all = use_all2all
  self.num_neighbors = num_neighbors
  self.max_input_size = 0
  self.with_edge = with_edge
  self.with_neg = with_neg
  self.with_weight = with_weight
  self.edge_dir = edge_dir
  self.collect_features = collect_features
  self.channel = channel
  self.concurrency = concurrency
  self.device = get_available_device(device)
  self.seed = seed

  if not isinstance(data, DistDataset):
    raise ValueError(f"'{self.__class__.__name__}': found invalid input "
                     f"data type '{type(data)}'")

  # Build the RPC router from the partition -> workers mapping returned by
  # rpc_sync_data_partitions (presumably a collective across workers).
  partition2workers = rpc_sync_data_partitions(
    num_data_partitions=self.data.num_partitions,
    current_partition_idx=self.data.partition_idx
  )
  self.rpc_router = RpcDataPartitionRouter(partition2workers)
  self.dist_graph = DistGraph(
    data.num_partitions, data.partition_idx,
    data.graph, data.node_pb, data.edge_pb
  )

  # Feature wrappers are only built on request; otherwise they stay None.
  self.dist_node_feature = None
  self.dist_edge_feature = None
  if self.collect_features:
    if data.node_features is not None:
      self.dist_node_feature = DistFeature(
        data.num_partitions, data.partition_idx,
        data.node_features, data.node_feat_pb,
        local_only=False, rpc_router=self.rpc_router, device=self.device
      )
    if self.with_edge and data.edge_features is not None:
      self.dist_edge_feature = DistFeature(
        data.num_partitions, data.partition_idx,
        data.edge_features, data.edge_feat_pb,
        local_only=False, rpc_router=self.rpc_router, device=self.device
      )

  # Node labels are kept as-is when they are plain tensors (homo) or a hetero
  # dict that is not made entirely of ``Feature`` objects. Otherwise, e.g. in
  # the v6d case, they are wrapped in a ``DistFeature`` that reuses the
  # node-feature partition book.
  labels = self.data.node_labels
  self.dist_node_labels = labels
  if self.dist_graph.data_cls == 'homo':
    wrap_labels = labels is not None and \
      not isinstance(labels, torch.Tensor)
  else:
    # Explicit check rather than ``assert`` so it still runs under -O.
    if labels is not None and not isinstance(labels, dict):
      raise ValueError(f"'{self.__class__.__name__}': expected heterogeneous "
                       f"node labels to be a dict or None, got "
                       f"'{type(labels)}'")
    wrap_labels = labels is not None and \
      all(isinstance(value, Feature) for value in labels.values())
  if wrap_labels:
    self.dist_node_labels = DistFeature(
      self.data.num_partitions, self.data.partition_idx,
      labels, self.data.node_feat_pb,
      local_only=False, rpc_router=self.rpc_router, device=self.device
    )

  self.sampler = NeighborSampler(
    self.dist_graph.local_graph, self.num_neighbors,
    self.device, self.with_edge, self.with_neg, self.with_weight,
    self.edge_dir, seed=self.seed
  )
  self.inducer_pool = queue.Queue(maxsize=self.concurrency)

  # Register RPC callees backed by the local sampler and keep the returned
  # ids (presumably used to address them in remote requests).
  rpc_sample_callee = RpcSamplingCallee(self.sampler, self.device)
  self.rpc_sample_callee_id = rpc_register(rpc_sample_callee)
  rpc_subgraph_callee = RpcSubGraphCallee(self.sampler, self.device)
  self.rpc_subgraph_callee_id = rpc_register(rpc_subgraph_callee)

  if self.dist_graph.data_cls == 'hetero':
    # Adopt the sampler's resolved fan-out and graph schema.
    self.num_neighbors = self.sampler.num_neighbors
    self.num_hops = self.sampler.num_hops
    self.edge_types = self.sampler.edge_types

  super().__init__(self.concurrency)
  # ``_loop`` is expected to be created by the base-class constructor above.
  self._loop.call_soon_threadsafe(ensure_device, self.device)