in graphlearn_torch/python/distributed/dist_feature.py [0:0]
def __init__(self,
             num_partitions: int,
             partition_idx: int,
             local_feature: Union[Feature,
                                  Dict[Union[NodeType, EdgeType], Feature]],
             feature_pb: Union[PartitionBook,
                               HeteroNodePartitionDict,
                               HeteroEdgePartitionDict],
             local_only: bool = False,
             rpc_router: Optional[RpcDataPartitionRouter] = None,
             device: Optional[torch.device] = None):
    """Set up local feature storage, partition books and RPC lookup.

    Args:
        num_partitions: Total number of partitions; stored as-is.
        partition_idx: Index of this process's partition; stored as-is.
        local_feature: Either a single ``Feature`` (homogeneous data, stored
            with ``data_cls == 'homo'``) or a dict keyed by node/edge type
            (heterogeneous data, ``data_cls == 'hetero'``).
            ``lazy_init_with_ipc_handle()`` is called on every ``Feature``
            found; dict values that are not ``Feature`` instances are left
            untouched.
        feature_pb: Partition book(s) for the features. Must be a dict when
            ``local_feature`` is a dict, and a ``PartitionBook`` or a
            ``torch.Tensor`` otherwise. A plain tensor (or a dict value that
            is not a ``PartitionBook``) is wrapped in ``GLTPartitionBook``.
        local_only: When ``False``, an ``RpcFeatureLookupCallee`` wrapping
            this object is registered via ``rpc_register`` and the returned
            id is kept in ``rpc_callee_id``; otherwise ``rpc_callee_id`` is
            ``None``.
        rpc_router: Router stored on the instance; required when
            ``local_only`` is ``False``.
        device: Requested device, resolved with ``get_available_device`` and
            then passed to ``ensure_device``.

    Raises:
        ValueError: If ``local_feature`` or ``feature_pb`` has an unsupported
            type, if their homogeneous/heterogeneous layouts disagree, or if
            ``rpc_router`` is ``None`` while ``local_only`` is ``False``.
    """
    self.num_partitions = num_partitions
    self.partition_idx = partition_idx
    self.device = get_available_device(device)
    ensure_device(self.device)

    # --- local features: decide homo/hetero layout --------------------------
    self.local_feature = local_feature
    if isinstance(self.local_feature, dict):
        self.data_cls = 'hetero'
        for feat in self.local_feature.values():
            # Only ``Feature`` entries are initialized; other values are skipped.
            if isinstance(feat, Feature):
                feat.lazy_init_with_ipc_handle()
    elif isinstance(self.local_feature, Feature):
        self.data_cls = 'homo'
        self.local_feature.lazy_init_with_ipc_handle()
    else:
        raise ValueError(f"'{self.__class__.__name__}': found invalid input "
                         f"feature type '{type(self.local_feature)}'")

    # --- partition books: validate type, then layout, then normalize --------
    self.feature_pb = feature_pb
    if isinstance(self.feature_pb, dict):
        expected_cls = 'hetero'
    elif isinstance(self.feature_pb, (PartitionBook, torch.Tensor)):
        expected_cls = 'homo'
    else:
        raise ValueError(f"'{self.__class__.__name__}': found invalid input "
                         f"partition book type '{type(self.feature_pb)}'")
    # Checked explicitly rather than with ``assert`` so the validation is not
    # stripped when Python runs with ``-O``.
    if self.data_cls != expected_cls:
        raise ValueError(f"'{self.__class__.__name__}': partition book type "
                         f"'{type(self.feature_pb)}' does not match the "
                         f"'{self.data_cls}' layout of the local features")
    if expected_cls == 'hetero':
        # NOTE: entries are converted in place, so the dict passed in by the
        # caller is modified as well. Replacing values of existing keys does
        # not change the dict size, so doing it while iterating is safe.
        for key, pb in self.feature_pb.items():
            if not isinstance(pb, PartitionBook):
                self.feature_pb[key] = GLTPartitionBook(pb)
    elif not isinstance(self.feature_pb, PartitionBook):
        # A raw tensor: wrap it. Objects that are already a ``PartitionBook``
        # are kept unchanged, even if they are also tensors.
        self.feature_pb = GLTPartitionBook(self.feature_pb)

    # --- RPC registration for remote feature lookups ------------------------
    self.rpc_router = rpc_router
    if local_only:
        self.rpc_callee_id = None
    else:
        if self.rpc_router is None:
            raise ValueError(f"'{self.__class__.__name__}': a rpc router must be "
                             f"provided when `local_only` set to `False`")
        rpc_callee = RpcFeatureLookupCallee(self)
        self.rpc_callee_id = rpc_register(rpc_callee)