in para_graph_sampler/graph_engine/frontend/graph.py [0:0]
def __post_init__(self, cap_node_full, cap_edge_full, cap_node_subg, cap_edge_subg, validate: bool):
    """
    Narrow the dtypes of the data fields using the given capacity bounds, then optionally validate.

    All subgraphs sampled by the same sampler should have the same dtype, since cap_*_subg are an upper bound
    for all subgraphs under that sampler.

    Parameters
    ----------
    cap_node_full : int or None
        Upper bound used to choose the dtype of ``node``.
    cap_edge_full : int or None
        Upper bound used to choose the dtype of ``edge_index``.
    cap_node_subg : int or None
        Upper bound used to choose the dtype of ``indices`` and ``target``.
    cap_edge_subg : int or None
        Upper bound used to choose the dtype of ``indptr``.
    validate : bool
        If True, ``self.check_valid()`` is called at the end.

    Dtype narrowing happens only when all four ``cap_*`` values are given; otherwise the fields are left
    untouched. Each bound maps to uint16 if < 2**16, uint32 if < 2**32, and int64 otherwise. ``data`` is
    always cast to float32.
    """
    caps = (cap_node_full, cap_edge_full, cap_node_subg, cap_edge_subg)
    if all(cap is not None for cap in caps):
        def index_dtype(cap):
            # Smallest integer type selected by the bound; huge bounds keep int64.
            if cap < 2**16:
                return np.uint16
            if cap < 2**32:
                return np.uint32
            return np.int64

        dtype = {
            'indptr'    : index_dtype(cap_edge_subg),
            'indices'   : index_dtype(cap_node_subg),
            'data'      : np.float32,
            'node'      : index_dtype(cap_node_full),
            'edge_index': index_dtype(cap_edge_full),
            'target'    : index_dtype(cap_node_subg),
        }
        # Every declared data field must have a dtype rule (and vice versa).
        assert set(dtype.keys()) == set(self.names_data_fields)
        for n in self.names_data_fields:
            v = getattr(self, n)
            if v is not None:
                # copy=False avoids a copy when the array already has the target dtype.
                setattr(self, n, v.astype(dtype[n], copy=False))
        # Explicitly handle data -- if it is all 1, replace it with a read-only broadcast view of a
        # single 1. Build that constant with the same dtype as the cast above, so all-ones subgraphs
        # do not silently become float64 while the others are float32.
        if self.data is not None and np.all(self.data == 1.):
            self.data = np.broadcast_to(np.array([1.], dtype=dtype['data']), self.data.size)
    if validate:
        self.check_valid()