in hugegraph-ml/src/hugegraph_ml/utils/dgl2hugegraph_utils.py [0:0]
def load_acm_raw():
    """Build the ACM heterogeneous graph from the raw ``ACM.mat`` file.

    Adapted from the DGL HAN example:
    https://github.com/dmlc/dgl/blob/master/examples/pytorch/han/utils.py

    The ``.mat`` file is looked up under ``get_download_dir()`` and is
    downloaded when it is not already there. Papers that have an entry in
    one of the selected conference columns are kept and connected to their
    authors and fields; the dense paper-term matrix becomes the paper
    feature matrix and the conference decides the paper label.

    The train/val/test split is drawn with ``np.random``, so it is only
    reproducible when the caller seeds NumPy's global generator beforehand.

    Returns:
        The graph created by ``dgl.heterograph`` with edge types
        ``pa``/``ap`` (paper-author) and ``pf``/``fp`` (paper-field). Its
        ``"paper"`` nodes carry ``feat``, ``label``, ``train_mask``,
        ``val_mask`` and ``test_mask``.
    """
    url = "dataset/ACM.mat"
    data_path = get_download_dir() + "/ACM.mat"
    if not os.path.exists(data_path):
        print(f"File {data_path} not found, downloading...")
        download(_get_dgl_url(url), path=data_path)

    mat = scipy.io.loadmat(data_path)
    paper_field = mat["PvsL"]  # paper-field?
    paper_author = mat["PvsA"]  # paper-author
    paper_term = mat["PvsT"]  # paper-term, bag of words
    paper_conf = mat["PvsC"]  # paper-conference, labels come from that

    # Conference column -> class label:
    # (1) KDD papers as class 0 (data mining),
    # (2) SIGMOD and VLDB papers as class 1 (database),
    # (3) SIGCOMM and MOBICOMM papers as class 2 (communication)
    conf_ids = [0, 1, 9, 10, 13]
    label_ids = [0, 1, 2, 2, 1]

    # Row index of every non-zero entry inside the selected conference
    # columns; these rows are the papers that are kept.
    selected_entries = paper_conf[:, conf_ids].tocoo()
    p_selected = selected_entries.row
    paper_field, paper_author, paper_term, paper_conf = [
        matrix[p_selected]
        for matrix in (paper_field, paper_author, paper_term, paper_conf)
    ]

    # Every relation is stored in both directions.
    edges = {
        ("paper", "pa", "author"): paper_author.nonzero(),
        ("author", "ap", "paper"): paper_author.transpose().nonzero(),
        ("paper", "pf", "field"): paper_field.nonzero(),
        ("field", "fp", "paper"): paper_field.transpose().nonzero(),
    }
    graph = dgl.heterograph(edges)

    features = torch.FloatTensor(paper_term.toarray())

    # (row, column) of every non-zero paper-conference entry.
    entry_paper, entry_conf = paper_conf.nonzero()

    label_array = np.zeros(len(p_selected), dtype=np.int64)
    for conference, label in zip(conf_ids, label_ids):
        papers_of_conference = entry_paper[entry_conf == conference]
        label_array[papers_of_conference] = label
    labels = torch.LongTensor(label_array)

    # Per conference, hand out evenly spaced scores in [0, 1] in random
    # order; thresholding the scores then gives roughly a 20/10/70
    # train/val/test split inside each conference.
    split_score = np.zeros(len(entry_paper))
    for conference in conf_ids:
        in_conference = entry_conf == conference
        evenly_spaced = np.linspace(0, 1, in_conference.sum())
        split_score[in_conference] = np.random.permutation(evenly_spaced)

    # NOTE(review): these indices are positions in the non-zero entry list
    # and are used as paper node ids below - that presumes one conference
    # entry per kept paper, listed in row order; confirm against the data.
    is_train = split_score <= 0.2
    is_val = (split_score > 0.2) & (split_score <= 0.3)
    is_test = split_score > 0.3
    train_idx = np.flatnonzero(is_train)
    val_idx = np.flatnonzero(is_val)
    test_idx = np.flatnonzero(is_test)

    num_nodes = graph.num_nodes("paper")
    paper_fields = {
        "feat": features,
        "label": labels,
        "train_mask": _get_mask(num_nodes, train_idx),
        "val_mask": _get_mask(num_nodes, val_idx),
        "test_mask": _get_mask(num_nodes, test_idx),
    }
    for key, value in paper_fields.items():
        graph.nodes["paper"].data[key] = value
    return graph