in hugegraph-ml/src/hugegraph_ml/data/hugegraph2dgl.py [0:0]
def _convert_graph_from_v_e(vertices, edges, feat_key=None, label_key=None, mask_keys=None):
    """Build a DGL graph from vertex and edge records.

    Each vertex is given a node index equal to its position in ``vertices``;
    edge endpoints are translated from vertex ids to those indices.

    Args:
        vertices: Sequence of mappings, each with an ``"id"`` entry and a
            ``"properties"`` mapping (the latter is only read when a feature,
            label, or mask key is requested).
        edges: Iterable of mappings with ``"outV"`` (source vertex id) and
            ``"inV"`` (destination vertex id) entries.
        feat_key: Optional property name whose per-vertex values are stored
            as ``ndata["feat"]`` (float32).
        label_key: Optional property name whose per-vertex values are stored
            as ``ndata["label"]`` (long).
        mask_keys: Optional iterable of property names; each one found is
            stored as a boolean tensor in ``ndata`` under the same name.

    Returns:
        The constructed DGL graph. An empty graph is returned (after emitting
        a warning) when ``vertices`` is empty.

    Raises:
        KeyError: If an edge refers to a vertex id that is not in
            ``vertices``, or if a property present on the first vertex is
            missing from a later vertex.
    """
    if len(vertices) == 0:
        warnings.warn("This graph has no vertices", Warning)
        # Build the empty graph from an explicit (src, dst) pair of empty
        # lists rather than a bare empty tuple, which is not a valid edge
        # specification.
        return dgl.graph(([], []))

    vertex_id_to_idx = {v["id"]: idx for idx, v in enumerate(vertices)}
    src_idx = [vertex_id_to_idx[e["outV"]] for e in edges]
    dst_idx = [vertex_id_to_idx[e["inV"]] for e in edges]

    # Pass the node count explicitly: otherwise DGL infers it from the
    # largest endpoint index, which drops trailing isolated vertices and
    # makes the per-vertex tensors below mismatch the graph's node count.
    graph_dgl = dgl.graph((src_idx, dst_idx), num_nodes=len(vertices))

    # Presence of each property is decided by the first vertex only; the
    # remaining vertices are expected to carry the same properties.
    if feat_key and feat_key in vertices[0]["properties"]:
        node_feats = [v["properties"][feat_key] for v in vertices]
        graph_dgl.ndata["feat"] = torch.tensor(node_feats, dtype=torch.float32)

    if label_key and label_key in vertices[0]["properties"]:
        node_labels = [v["properties"][label_key] for v in vertices]
        graph_dgl.ndata["label"] = torch.tensor(node_labels, dtype=torch.long)

    if mask_keys:
        for mk in mask_keys:
            if mk in vertices[0]["properties"]:
                node_masks = [v["properties"][mk] for v in vertices]
                graph_dgl.ndata[mk] = torch.tensor(node_masks, dtype=torch.bool)

    return graph_dgl