def cat_to_block_diagonal()

in para_graph_sampler/graph_engine/frontend/graph.py [0:0]


    def cat_to_block_diagonal(cls, subgs: list):
        """
        Concatenate subgraphs into a full adj matrix (i.e., into the block diagonal form).

        Each subgraph's CSR adjacency (``indptr`` / ``indices`` / ``data``) becomes one
        diagonal block of the combined matrix:
        - Local node ids (``indices`` and ``target``) of the i-th subgraph are shifted
          by the total ``node.size`` of subgraphs 0..i-1.
        - Its ``indptr`` values are shifted by the total ``edge_index.size`` of those
          subgraphs.
        - ``node``, ``edge_index`` and ``data`` are concatenated unchanged.

        Parameters
        ----------
        subgs : list
            Non-empty list of subgraph objects exposing ``node``, ``edge_index``,
            ``data``, ``target``, ``indptr`` and ``indices`` numpy arrays
            (presumably instances of ``cls`` — confirm against callers).

        Returns
        -------
        cls
            A new instance built from the concatenated arrays with ``validate=True``.

        Raises
        ------
        ValueError
            If ``subgs`` is empty.
        """
        if not subgs:
            raise ValueError("cat_to_block_diagonal requires at least one subgraph")
        # Exclusive prefix sums: entry i is the total size of subgraphs 0..i-1.
        # dtype is explicit so the offsets are int64 on every platform.
        node_sizes = [s.node.size for s in subgs]
        nnz_sizes = [s.edge_index.size for s in subgs]
        offset_indices = np.cumsum([0] + node_sizes[:-1], dtype=np.int64)
        offset_indptr = np.cumsum([0] + nnz_sizes[:-1], dtype=np.int64)
        # Concatenated as-is; dtype is preserved when all subgraphs agree.
        node_batch = np.concatenate([s.node for s in subgs])
        edge_index_batch = np.concatenate([s.edge_index for s in subgs])
        data_batch = np.concatenate([s.data for s in subgs])
        target_parts, indptr_parts, indices_parts = [], [], []
        for i, (s, off_node, off_nnz) in enumerate(zip(subgs, offset_indices, offset_indptr)):
            indptr = s.indptr.astype(np.int64)
            if i > 0:
                # After shifting, the last entry of block i-1's indptr equals the first
                # entry of block i's indptr, so drop that duplicate for every block but
                # the first.
                indptr = indptr[1:]
            target_parts.append(s.target.astype(np.int64) + off_node)
            indptr_parts.append(indptr + off_nnz)
            indices_parts.append(s.indices.astype(np.int64) + off_node)
        target_batch = np.concatenate(target_parts)
        indptr_batch = np.concatenate(indptr_parts)
        indices_batch = np.concatenate(indices_parts)
        entity_enc_batch = EntityEncoding.cat_batch(subgs)
        return cls(
            indptr=indptr_batch,
            indices=indices_batch,
            data=data_batch,
            node=node_batch,
            edge_index=edge_index_batch,
            target=target_batch,
            entity_enc=entity_enc_batch,
            # Effectively "no cap": concatenated subgraphs are only used for one batch.
            # NOTE(review): 2**63 exceeds the int64 max; fine if the caps stay Python
            # ints, but would overflow if cast to np.int64 — confirm in __init__.
            cap_node_full=2**63,
            cap_edge_full=2**63,
            cap_node_subg=2**63,
            cap_edge_subg=2**63,
            validate=True,
        )