graphlearn_torch/python/partition/frequency_partitioner.py (147 lines of code) (raw):

# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== from typing import List, Dict, Optional, Tuple, Union import torch from ..typing import NodeType, EdgeType, TensorDataType from ..utils import parse_size from .base import PartitionerBase, PartitionBook class FrequencyPartitioner(PartitionerBase): r""" Frequency-based partitioner for graph topology and features. Args: output_dir: The output root directory for partitioned results. num_parts: Number of partitions. num_nodes: Number of graph nodes, should be a dict for hetero data. edge_index: The edge index data of graph edges, should be a dict for hetero data. probs: The node access distribution on each partition, should be a dict for hetero data. node_feat: The node feature data, should be a dict for hetero data. node_feat_dtype: The data type of node features. edge_feat: The edge feature data, should be a dict for hetero data. edge_feat_dtype: The data type of edge features. edge_weights: The edge weights, should be a dict for hetero data. edge_assign_strategy: The assignment strategy when partitioning edges, should be 'by_src' or 'by_dst'. cache_memory_budget: The memory budget (in bytes) for cached node features per partition for each node type, should be a dict for hetero data. cache_ratio: The proportion to cache node features per partition for each node type, should be a dict for hetero data. chunk_size: The chunk size for partitioning. Note that if both `cache_memory_budget` and `cache_ratio` are provided, the metric that caches the smaller number of features will be used. If both of them set to empty dict, the feature cache will be turned off. """ def __init__( self, output_dir: str, num_parts: int, num_nodes: Union[int, Dict[NodeType, int]], edge_index: Union[TensorDataType, Dict[EdgeType, TensorDataType]], probs: Union[List[torch.Tensor], Dict[NodeType, List[torch.Tensor]]], node_feat: Optional[Union[TensorDataType, Dict[NodeType, TensorDataType]]] = None, node_feat_dtype: torch.dtype = torch.float32, edge_feat: Optional[Union[TensorDataType, Dict[EdgeType, TensorDataType]]] = None, edge_feat_dtype: torch.dtype = torch.float32, edge_weights: Optional[Union[TensorDataType, Dict[EdgeType, TensorDataType]]] = None, edge_assign_strategy: str = 'by_src', cache_memory_budget: Union[int, Dict[NodeType, int]] = None, cache_ratio: Union[float, Dict[NodeType, float]] = None, chunk_size: int = 10000, ): super().__init__(output_dir, num_parts, num_nodes, edge_index, node_feat, node_feat_dtype, edge_feat, edge_feat_dtype, edge_weights, edge_assign_strategy, chunk_size) self.probs = probs if self.node_feat is not None: if 'hetero' == self.data_cls: self.per_feature_bytes = {} for ntype, feat in self.node_feat.items(): assert len(feat.shape) == 2 self.per_feature_bytes[ntype] = feat.shape[1] * feat.element_size() assert isinstance(self.probs, dict) for ntype, prob_list in self.probs.items(): assert ntype in self.node_types assert len(prob_list) == self.num_parts else: assert len(self.node_feat.shape) == 2 self.per_feature_bytes = (self.node_feat.shape[1] * self.node_feat.element_size()) assert len(self.probs) == self.num_parts self.blob_size = self.chunk_size * self.num_parts if cache_memory_budget is None: self.cache_memory_budget = {} if 'hetero' == self.data_cls else 0 else: self.cache_memory_budget = cache_memory_budget if cache_ratio is None: self.cache_ratio = {} if 'hetero' == self.data_cls else 0.0 else: self.cache_ratio = cache_ratio def _get_chunk_probs_sum( self, chunk: torch.Tensor, probs: List[torch.Tensor] ) -> List[torch.Tensor]: r""" Helper function for partitioning a certain type of node to calculate hotness and difference between partitions. """ chunk_probs_sum = [ (torch.zeros(chunk.size(0)) + 1e-6) for _ in range(self.num_parts) ] for src_rank in range(self.num_parts): for dst_rank in range(self.num_parts): if dst_rank == src_rank: chunk_probs_sum[src_rank] += probs[dst_rank][chunk] * self.num_parts else: chunk_probs_sum[src_rank] -= probs[dst_rank][chunk] return chunk_probs_sum def _partition_node( self, ntype: Optional[NodeType] = None ) -> Tuple[List[torch.Tensor], PartitionBook]: if 'hetero' == self.data_cls: assert ntype is not None node_num = self.num_nodes[ntype] probs = self.probs[ntype] else: node_num = self.num_nodes probs = self.probs chunk_num = (node_num + self.chunk_size - 1) // self.chunk_size res = [[] for _ in range(self.num_parts)] current_chunk_start_pos = 0 current_partition_idx = 0 for _ in range(chunk_num): current_chunk_end_pos = min(node_num, current_chunk_start_pos + self.blob_size) current_chunk_size = current_chunk_end_pos - current_chunk_start_pos chunk = torch.arange(current_chunk_start_pos, current_chunk_end_pos, dtype=torch.long) chunk_probs_sum = self._get_chunk_probs_sum(chunk, probs) assigned_node_size = 0 per_partition_size = self.chunk_size for partition_idx in range(current_partition_idx, current_partition_idx + self.num_parts): partition_idx = partition_idx % self.num_parts actual_per_partition_size = min(per_partition_size, chunk.size(0) - assigned_node_size) _, sorted_res_order = torch.sort(chunk_probs_sum[partition_idx], descending=True) pick_chunk_part = sorted_res_order[:actual_per_partition_size] pick_ids = chunk[pick_chunk_part] res[partition_idx].append(pick_ids) for idx in range(self.num_parts): chunk_probs_sum[idx][pick_chunk_part] = -self.num_parts assigned_node_size += actual_per_partition_size current_partition_idx += 1 current_chunk_start_pos += current_chunk_size partition_book = torch.zeros(node_num, dtype=torch.long) partition_results = [] for partition_idx in range(self.num_parts): partition_ids = torch.cat(res[partition_idx]) partition_results.append(partition_ids) partition_book[partition_ids] = partition_idx return partition_results, partition_book def _cache_node( self, ntype: Optional[NodeType] = None ) -> List[Optional[torch.Tensor]]: if 'hetero' == self.data_cls: assert ntype is not None probs = self.probs[ntype] per_feature_bytes = self.per_feature_bytes[ntype] cache_memory_budget = self.cache_memory_budget.get(ntype, 0) cache_ratio = self.cache_ratio.get(ntype, 0.0) else: probs = self.probs per_feature_bytes = self.per_feature_bytes cache_memory_budget = self.cache_memory_budget cache_ratio = self.cache_ratio cache_memory_budget_bytes = parse_size(cache_memory_budget) cache_num_by_memory = int(cache_memory_budget_bytes / (per_feature_bytes + 1e-6)) cache_num_by_memory = min(cache_num_by_memory, probs[0].size(0)) cache_num_by_ratio = int(probs[0].size(0) * min(cache_ratio, 1.0)) if cache_num_by_memory == 0: cache_num = cache_num_by_ratio elif cache_num_by_ratio == 0: cache_num = cache_num_by_memory else: cache_num = min(cache_num_by_memory, cache_num_by_ratio) cache_results = [None] * self.num_parts if cache_num > 0: for partition_idx in range(self.num_parts): _, prev_order = torch.sort(probs[partition_idx], descending=True) cache_results[partition_idx] = prev_order[:cache_num] return cache_results