def build_partition_feature()

in graphlearn_torch/python/partition/base.py [0:0]


def _select_partition_ids(partition_book: torch.Tensor,
                          partition_idx: int) -> torch.Tensor:
  r""" Return the ids assigned to ``partition_idx`` by a partition book.

  The id of an entry is its position along the first dimension of
  ``partition_book``; the selected ids keep that positional order.

  Args:
    partition_book: Tensor whose i-th entry is compared against
      ``partition_idx`` (assumes a 1-D book -- TODO confirm against the
      partitioner that wrote it).
    partition_idx (int): The partition idx to select.

  Returns:
    An int64 tensor holding the selected ids.
  """
  all_ids = torch.arange(partition_book.size(0), dtype=torch.int64)
  return torch.masked_select(all_ids, partition_book == partition_idx)


def _save_feature_chunks(
  root_dir: str,
  partition_idx: int,
  feat: torch.Tensor,
  ids: torch.Tensor,
  chunk_size: int,
  group: str,
  graph_type: Optional[Union[NodeType, EdgeType]] = None):
  r""" Split ``ids`` into chunks and persist ``feat[chunk]`` for each chunk.

  Args:
    root_dir (str): The root directory for saved partition files.
    partition_idx (int): The partition idx.
    feat: The feature data, indexed with each chunk of ``ids``.
    ids: The ids whose features belong to this partition.
    chunk_size: Upper bound on the number of ids per persisted chunk.
    group (str): Passed through to ``save_feature_partition_chunk``
      (``'node_feat'`` or ``'edge_feat'`` at the call sites below).
    graph_type: Passed through to ``save_feature_partition_chunk``; None
      for homogeneous data.
  """
  num_ids = ids.shape[0]
  if num_ids == 0:
    # Nothing to persist. This guard is required: the chunk count below
    # would be 0 and ``torch.chunk`` rejects a non-positive ``chunks``.
    return
  num_chunks = (num_ids + chunk_size - 1) // chunk_size
  for chunk in torch.chunk(ids, chunks=num_chunks):
    feat_chunk = FeaturePartitionData(
      feats=feat[chunk],
      # ``torch.chunk`` returns views of ``ids``; clone so each persisted
      # chunk owns its id tensor.
      ids=chunk.clone(),
      cache_feats=None,
      cache_ids=None
    )
    save_feature_partition_chunk(root_dir, partition_idx, feat_chunk,
                                 group=group, graph_type=graph_type)


def build_partition_feature(
  root_dir: str, 
  partition_idx: int,
  chunk_size: int = 10000,
  node_feat: Optional[Union[TensorDataType, Dict[NodeType, TensorDataType]]] = None,
  node_feat_dtype: torch.dtype = torch.float32,
  edge_feat: Optional[Union[TensorDataType, Dict[EdgeType, TensorDataType]]] = None,
  edge_feat_dtype: torch.dtype = torch.float32):
  r""" Extract and persist the features of one partition, for the case
  where the graph topology has already been partitioned but the feature
  partitioning has not been executed.

  Node ids of the partition are read from the saved node partition book(s)
  and edge ids from the saved graph partition data; the matching feature
  rows are then written in chunks via ``save_feature_partition_chunk``.
  A feature argument that is None after tensor conversion is skipped, and
  so is any node/edge type for which this partition owns no ids.

  Args:
    root_dir (str): The root directory for saved partition files.
    partition_idx (int): The partition idx.
    chunk_size: The chunk size for partitioning, must be positive.
    node_feat: The node feature data, should be a dict for hetero data.
    node_feat_dtype: The data type of node features.
    edge_feat: The edge feature data, should be a dict for hetero data.
    edge_feat_dtype: The data type of edge features.

  Raises:
    AssertionError: If ``chunk_size`` is not positive, ``partition_idx`` is
      outside ``[0, num_parts)``, or the partition directory is missing.
  """
  assert chunk_size > 0, f'chunk_size must be positive, got {chunk_size}'

  # NOTE: 'META' and the '.pt' files below are deserialized with pickle /
  # torch.load; only point ``root_dir`` at trusted partition output.
  with open(os.path.join(root_dir, 'META'), 'rb') as infile:
    meta = pickle.load(infile)
  num_partitions = meta['num_parts']
  assert partition_idx >= 0, \
    f'partition_idx must be non-negative, got {partition_idx}'
  assert partition_idx < num_partitions, \
    f'partition_idx {partition_idx} out of range for {num_partitions} partitions'
  partition_dir = os.path.join(root_dir, f'part{partition_idx}')
  assert os.path.exists(partition_dir), \
    f'partition directory {partition_dir} does not exist'
  graph_dir = os.path.join(partition_dir, 'graph')
  device = torch.device('cpu')

  # The original edge path already relies on the converted value being
  # None when no feature was given; the node path is now treated the same.
  node_feat = convert_to_tensor(node_feat, dtype=node_feat_dtype)
  edge_feat = convert_to_tensor(edge_feat, dtype=edge_feat_dtype)

  # homogenous
  if meta['data_cls'] == 'homo':
    # step 1: build and persist the node feature partition
    if node_feat is not None:
      node_pb = torch.load(os.path.join(root_dir, 'node_pb.pt'),
        map_location=device)
      n_ids = _select_partition_ids(node_pb, partition_idx)
      _save_feature_chunks(root_dir, partition_idx, node_feat, n_ids,
                           chunk_size, group='node_feat', graph_type=None)

    # step 2: build and persist the edge feature partition
    if edge_feat is not None:
      graph = load_graph_partition_data(graph_dir, device)
      _save_feature_chunks(root_dir, partition_idx, edge_feat, graph.eids,
                           chunk_size, group='edge_feat', graph_type=None)
    return

  # heterogenous
  # step 1: build and persist the node feature partition
  if node_feat is not None:
    node_pb_dir = os.path.join(root_dir, 'node_pb')
    for ntype, feat in node_feat.items():
      node_pb = torch.load(
        os.path.join(node_pb_dir, f'{as_str(ntype)}.pt'), map_location=device)
      n_ids = _select_partition_ids(node_pb, partition_idx)
      _save_feature_chunks(root_dir, partition_idx, feat, n_ids,
                           chunk_size, group='node_feat', graph_type=ntype)

  # step 2: build and persist the edge feature partition
  if edge_feat is not None:
    for etype, feat in edge_feat.items():
      graph = load_graph_partition_data(
        os.path.join(graph_dir, as_str(etype)), device)
      _save_feature_chunks(root_dir, partition_idx, feat, graph.eids,
                           chunk_size, group='edge_feat', graph_type=etype)