# graphlearn_torch/python/distributed/dist_table_dataset.py
# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ODPS table related distributed partitioner and dataset."""
import datetime
from multiprocessing.reduction import ForkingPickler
import numpy as np
import torch
import time
from typing import Dict, Optional, Union
try:
import common_io
except ImportError:
pass
from ..typing import (
NodeType, EdgeType, TensorDataType,
)
from .dist_dataset import DistDataset, _cat_feature_cache
from .dist_random_partitioner import DistRandomPartitioner
class DistTableRandomPartitioner(DistRandomPartitioner):
r""" A distributed random partitioner for parallel partitioning with large
scale edge tables and node tables.
Each distributed partitioner will process a slice of the full table,
and partition them in parallel. After partitioning, each distributed
partitioner will own a partitioned graph with its corresponding rank.
Args:
num_nodes: Number of all graph nodes, should be a dict for hetero data.
edge_index: A part of the edge index data of graph edges, should be a dict
for hetero data.
edge_ids: The edge ids of the input ``edge_index``.
node_feat: A part of the node feature data, should be a dict for hetero data.
node_feat_ids: The node ids corresponding to the input ``node_feat``.
edge_feat: The edge feature data, should be a dict for hetero data.
edge_feat_ids: The edge ids corresponding to the input ``edge_feat``.
num_parts: The number of all partitions. If not provided, the value of
``graphlearn_torch.distributed.get_context().world_size`` will be used.
current_partition_idx: The partition index corresponding to the current
distributed partitioner. If not provided, the value of
``graphlearn_torch.distributed.get_context().rank`` will be used.
node_feat_dtype: The data type of node features.
edge_feat_dtype: The data type of edge features.
edge_assign_strategy: The assignment strategy when partitioning edges,
should be 'by_src' or 'by_dst'.
chunk_size: The chunk size for partitioning.
master_addr: The master TCP address for RPC connection between all
distributed partitioners.
master_port: The master TCP port for RPC connection between all
distributed partitioners.
num_rpc_threads: The number of RPC worker threads to use.
Returns:
int: Number of all partitions.
int: The current partition idx.
GraphPartitionData/HeteroGraphPartitionData: graph partition data.
FeaturePartitionData/HeteroFeaturePartitionData: node feature partition
data, optional.
FeaturePartitionData/HeteroFeaturePartitionData: edge feature partition
data, optional.
PartitionBook/HeteroNodePartitionDict: node partition book.
PartitionBook/HeteroEdgePartitionDict: edge partition book.
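
  Example:
    A minimal calling sketch (illustrative only): it assumes each of the
    ``num_parts`` workers has already read its slice of the edge table into
    ``src_ids``, ``dst_ids`` and ``edge_ids`` tensors, and knows its ``rank``
    and the RPC endpoint of the master worker.

      partitioner = DistTableRandomPartitioner(
        num_nodes=total_num_nodes,
        edge_index=torch.stack([src_ids, dst_ids]),
        edge_ids=edge_ids,
        num_parts=num_parts,
        current_partition_idx=rank,
        master_addr='localhost',
        master_port='11234',
      )
      (n_parts, cur_idx, graph, node_feat, edge_feat,
       node_pb, edge_pb) = partitioner.partition()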
"""
def __init__(
self,
num_nodes: Union[int, Dict[NodeType, int]],
edge_index: Union[TensorDataType, Dict[EdgeType, TensorDataType]],
edge_ids: Union[TensorDataType, Dict[EdgeType, TensorDataType]],
node_feat: Optional[Union[TensorDataType, Dict[NodeType, TensorDataType]]] = None,
node_feat_ids: Optional[Union[TensorDataType, Dict[NodeType, TensorDataType]]] = None,
edge_feat: Optional[Union[TensorDataType, Dict[EdgeType, TensorDataType]]] = None,
edge_feat_ids: Optional[Union[TensorDataType, Dict[EdgeType, TensorDataType]]] = None,
num_parts: Optional[int] = None,
current_partition_idx: Optional[int] = None,
node_feat_dtype: torch.dtype = torch.float32,
edge_feat_dtype: torch.dtype = torch.float32,
edge_assign_strategy: str = 'by_src',
chunk_size: int = 10000,
master_addr: Optional[str] = None,
master_port: Optional[str] = None,
num_rpc_threads: int = 16,
):
super().__init__('', num_nodes, edge_index, edge_ids, node_feat, node_feat_ids,
edge_feat, edge_feat_ids, num_parts, current_partition_idx,
node_feat_dtype, edge_feat_dtype, edge_assign_strategy,
chunk_size, master_addr, master_port, num_rpc_threads)
def partition(self):
r""" Partition graph and feature data into different parts along with all
other distributed partitioners, save the result of the current partition
index into output directory.
"""
if 'hetero' == self.data_cls:
node_pb_dict = {}
node_feat_dict = {}
for ntype in self.node_types:
node_pb = self._partition_node(ntype)
node_pb_dict[ntype] = node_pb
current_node_feat_part = self._partition_node_feat(node_pb, ntype)
if current_node_feat_part is not None:
node_feat_dict[ntype] = current_node_feat_part
edge_pb_dict = {}
graph_dict = {}
edge_feat_dict = {}
for etype in self.edge_types:
current_graph_part, edge_pb = self._partition_graph(node_pb_dict, etype)
edge_pb_dict[etype] = edge_pb
graph_dict[etype] = current_graph_part
current_edge_feat_part = self._partition_edge_feat(edge_pb, etype)
if current_edge_feat_part is not None:
edge_feat_dict[etype] = current_edge_feat_part
return (
self.num_parts, self.current_partition_idx,
graph_dict, node_feat_dict, edge_feat_dict, node_pb_dict, edge_pb_dict
)
else:
node_pb = self._partition_node()
node_feat = self._partition_node_feat(node_pb)
graph, edge_pb = self._partition_graph(node_pb)
edge_feat = self._partition_edge_feat(edge_pb)
return (
self.num_parts, self.current_partition_idx,
graph, node_feat, edge_feat, node_pb, edge_pb
)
class DistTableDataset(DistDataset):
""" Creates `DistDataset` from ODPS tables.
Args:
edge_tables: A dict({edge_type : odps_table}) denoting each
bipartite graph input table of heterogeneous graph, where edge_type is
a tuple of (src_type, edge_type, dst_type).
node_tables: A dict({node_type(str) : odps_table}) denoting each
input node table.
num_nodes: Number of all graph nodes, should be a dict for hetero data.
graph_mode: mode in graphlearn_torch's `Graph`, 'CPU', 'ZERO_COPY'
or 'CUDA'.
    sort_func: Function for feature reordering, which returns the feature
      data (2D tensor) and a map (1D tensor) from id to index.
split_ratio: The proportion of data allocated to the GPU, between 0 and 1.
device_group_list: A list of `DeviceGroup`. Each DeviceGroup must have the
same size. A group of GPUs with peer-to-peer access to each other should
be set in the same device group for high feature collection performance.
directed: A Boolean value indicating whether the graph topology is
directed.
reader_threads: The number of threads of table reader.
reader_capacity: The capacity of table reader.
reader_batch_size: The number of records read at once.
label: A CPU torch.Tensor(homo) or a Dict[NodeType, torch.Tensor](hetero)
with the label data for graph nodes.
device: The target cuda device rank to perform graph operations and
feature lookups.
feature_with_gpu (bool): A Boolean value indicating whether the created
``Feature`` objects of node/edge features use ``UnifiedTensor``.
      If True, ``Feature`` consists of ``UnifiedTensor``; otherwise,
      ``Feature`` is a PyTorch CPU tensor and the ``device_group_list`` and
      ``device`` arguments will be ignored. (default: ``True``)
edge_assign_strategy: The assignment strategy when partitioning edges,
should be 'by_src' or 'by_dst'.
chunk_size: The chunk size for partitioning.
master_addr: The master TCP address for RPC connection between all
distributed partitioners.
master_port: The master TCP port for RPC connection between all
distributed partitioners.
num_rpc_threads: The number of RPC worker threads to use.
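
  Example:
    A minimal calling sketch for a homogeneous graph (illustrative only; the
    ODPS table names, node count, world size/rank and RPC endpoint below are
    placeholders):

      dataset = DistTableDataset()
      dataset.load(
        num_partitions=world_size,
        partition_idx=rank,
        edge_tables={('item', 'i2i', 'item'): 'odps://project/tables/i2i_edge'},
        node_tables={'item': 'odps://project/tables/item_feature'},
        num_nodes=num_items,
        master_addr='localhost',
        master_port='11234',
      )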
"""
def load(
self,
num_partitions=1,
partition_idx=0,
edge_tables=None,
node_tables=None,
num_nodes=0,
graph_mode='ZERO_COPY',
device_group_list=None,
reader_threads=10,
reader_capacity=10240,
reader_batch_size=1024,
label=None,
device=None,
feature_with_gpu=True,
edge_assign_strategy='by_src',
chunk_size=10000,
master_addr=None,
master_port=None,
num_rpc_threads=16,
):
assert isinstance(edge_tables, dict)
assert isinstance(node_tables, dict)
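    # Holders for the loaded raw data: keyed by edge/node type for
    # heterogeneous input (more than one table), plain arrays/tensors
    # for homogeneous input.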
    edge_index, eids, feature, node_ids = {}, {}, {}, {}
edge_hetero = (len(edge_tables) > 1)
node_hetero = (len(node_tables) > 1)
print("Start Loading edge and node tables...")
step = 0
start_time = time.time()
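    # Each row of an edge table is expected to be (src_id, dst_id, edge_id),
    # matching the column order used below.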
for e_type, table in edge_tables.items():
edge_list = []
reader = common_io.table.TableReader(table,
slice_id=partition_idx,
slice_count=num_partitions,
num_threads=reader_threads,
capacity=reader_capacity)
while True:
try:
data = reader.read(reader_batch_size, allow_smaller_final_batch=True)
edge_list.extend(data)
step += 1
except common_io.exception.OutOfRangeException:
reader.close()
break
if step % 1000 == 0:
print(f"{datetime.datetime.now()}: load "
f"{step * reader_batch_size} edges.")
rows = [e[0] for e in edge_list]
cols = [e[1] for e in edge_list]
eids_array = np.array([e[2] for e in edge_list], dtype=np.int64)
edge_array = np.stack([np.array(rows, dtype=np.int64),
np.array(cols, dtype=np.int64)])
if edge_hetero:
        edge_index[e_type] = edge_array
        eids[e_type] = eids_array
else:
edge_index = edge_array
eids = eids_array
del rows
del cols
del edge_list
step = 0
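    # Each row of a node table is expected to be (node_id, feature), where
    # the feature column is a ':'-separated string (or bytes) of floats.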
for n_type, table in node_tables.items():
feature_list = []
reader = common_io.table.TableReader(table,
slice_id=partition_idx,
slice_count=num_partitions,
num_threads=reader_threads,
capacity=reader_capacity)
while True:
try:
data = reader.read(reader_batch_size, allow_smaller_final_batch=True)
feature_list.extend(data)
step += 1
except common_io.exception.OutOfRangeException:
reader.close()
break
if step % 1000 == 0:
print(f"{datetime.datetime.now()}: load "
f"{step * reader_batch_size} nodes.")
      ids = torch.tensor([feat[0] for feat in feature_list], dtype=torch.long)
      if isinstance(feature_list[0][1], bytes):
        float_feat = [
          list(map(float, feat[1].decode().split(':')))
          for feat in feature_list
        ]
      else:
        float_feat = [
          list(map(float, feat[1].split(':')))
          for feat in feature_list
        ]
      if node_hetero:
        feature[n_type] = torch.tensor(float_feat)
        node_ids[n_type] = ids
      else:
        feature = torch.tensor(float_feat)
        node_ids = ids
del float_feat
del feature_list
load_time = (time.time() - start_time) / 60
    print(f'Loading tables completed in {load_time:.2f} minutes.')
    print("Start partitioning graph and features...")
p_start = time.time()
dist_partitioner = DistTableRandomPartitioner(
num_nodes, edge_index=edge_index, edge_ids=eids,
      node_feat=feature, node_feat_ids=node_ids,
num_parts=num_partitions, current_partition_idx=partition_idx,
edge_assign_strategy=edge_assign_strategy,
chunk_size=chunk_size, master_addr=master_addr, master_port=master_port,
num_rpc_threads=num_rpc_threads)
(
self.num_partitions,
self.partition_idx,
graph_data,
node_feat_data,
edge_feat_data,
self.node_pb,
self.edge_pb
) = dist_partitioner.partition()
part_time = (time.time() - p_start) / 60
print(f'Partitioning completed in {part_time:.2f} minutes.')
# init graph
if isinstance(graph_data, dict):
# heterogeneous.
edge_index, edge_ids = {}, {}
for k, v in graph_data.items():
edge_index[k] = v.edge_index
edge_ids[k] = v.eids
else:
# homogeneous.
edge_index = graph_data.edge_index
edge_ids = graph_data.eids
self.init_graph(edge_index, edge_ids, layout='COO',
graph_mode=graph_mode, device=device)
# load node feature
if node_feat_data is not None:
node_cache_ratio, node_feat, node_feat_id2idx, node_feat_pb = \
_cat_feature_cache(partition_idx, node_feat_data, self.node_pb)
self.init_node_features(
node_feat, node_feat_id2idx, None, node_cache_ratio,
device_group_list, device, feature_with_gpu, dtype=None
)
self._node_feat_pb = node_feat_pb
# load edge feature
if edge_feat_data is not None:
edge_cache_ratio, edge_feat, edge_feat_id2idx, edge_feat_pb = \
_cat_feature_cache(partition_idx, edge_feat_data, self.edge_pb)
self.init_edge_features(
edge_feat, edge_feat_id2idx, edge_cache_ratio,
device_group_list, device, feature_with_gpu, dtype=None
)
self._edge_feat_pb = edge_feat_pb
# load whole node labels
self.init_node_labels(label)
## Pickling Registration
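# Registering a custom reducer with ``ForkingPickler`` allows a
# ``DistTableDataset`` to be sent to child processes (e.g. through
# torch.multiprocessing): pickling captures an IPC handle via ``share_ipc()``
# and the dataset is rebuilt from it with ``from_ipc_handle()`` on the
# receiving side.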
def rebuild_dist_table_dataset(ipc_handle):
ds = DistTableDataset.from_ipc_handle(ipc_handle)
return ds
def reduce_dist_table_dataset(dataset: DistTableDataset):
ipc_handle = dataset.share_ipc()
return (rebuild_dist_table_dataset, (ipc_handle, ))
ForkingPickler.register(DistTableDataset, reduce_dist_table_dataset)