def __post_init__()

in para_graph_sampler/graph_engine/frontend/graph.py [0:0]


    def __post_init__(self, cap_node_full, cap_edge_full, cap_node_subg, cap_edge_subg, validate: bool):
        """
        Narrow the dtypes of the data fields and optionally validate the graph.

        When all four capacity bounds are provided, every field listed in
        ``self.names_data_fields`` that is not None is cast (avoiding a copy
        when possible) to a dtype chosen from its bound:

        - bound < 2**16 -> np.uint16
        - bound < 2**32 -> np.uint32
        - otherwise     -> np.int64
        - ``data`` is always np.float32.

        All subgraphs sampled by the same sampler should have the same dtype,
        since cap_*_subg are an upper bound for all subgraphs under that
        sampler.

        Parameters
        ----------
        cap_node_full : int or None
            Bound used to choose the dtype of ``node``.
        cap_edge_full : int or None
            Bound used to choose the dtype of ``edge_index``.
        cap_node_subg : int or None
            Bound used to choose the dtype of ``indices`` and ``target``.
        cap_edge_subg : int or None
            Bound used to choose the dtype of ``indptr``.
        validate : bool
            If True, ``self.check_valid()`` is called at the end.
        """
        caps = (cap_node_full, cap_edge_full, cap_node_subg, cap_edge_subg)
        if all(cap is not None for cap in caps):
            # Defaults used when a bound does not fit in 32 bits.
            dtype = {
                'indptr'   : np.int64,
                'indices'  : np.int64,
                'data'     : np.float32,
                'node'     : np.int64,
                'edge_index': np.int64,
                'target'   : np.int64
            }

            def _narrow_dtype(bound):
                # Only called for bound < 2**32, so uint32 is always wide enough.
                return np.uint16 if bound < 2**16 else np.uint32

            if cap_node_full < 2**32:
                dtype['node'] = _narrow_dtype(cap_node_full)
            if cap_edge_full < 2**32:
                dtype['edge_index'] = _narrow_dtype(cap_edge_full)
            if cap_node_subg < 2**32:
                dtype['indices'] = _narrow_dtype(cap_node_subg)
                dtype['target']  = _narrow_dtype(cap_node_subg)
            if cap_edge_subg < 2**32:
                dtype['indptr']  = _narrow_dtype(cap_edge_subg)
            # Internal invariant: the dtype table must cover exactly the data fields.
            assert set(dtype.keys()) == set(self.names_data_fields)
            for n in self.names_data_fields:
                v = getattr(self, n)
                if v is not None:
                    setattr(self, n, v.astype(dtype[n], copy=False))
            # Explicitly handle data -- if it is all 1, replace it by a read-only
            # broadcast view of a single 1 instead of a full array. The view is
            # built with dtype['data'] so that all-ones subgraphs keep the same
            # dtype as every other subgraph (a bare np.array([1.]) is float64).
            if self.data is not None and np.all(self.data == 1.):
                self.data = np.broadcast_to(
                    np.array([1.], dtype=dtype['data']), self.data.size
                )
        if validate:
            self.check_valid()