def load_acm_raw()

in hugegraph-ml/src/hugegraph_ml/utils/dgl2hugegraph_utils.py [0:0]


def load_acm_raw(*, seed=None):
    """Load the raw ACM dataset as a DGL heterograph, downloading it if needed.

    Adapted from the DGL HAN example
    (https://github.com/dmlc/dgl/blob/master/examples/pytorch/han/utils.py).
    ``ACM.mat`` is fetched into ``get_download_dir()`` when not already
    present. Only papers from five selected conferences are kept. Each paper
    is labelled with a class derived from its conference. Papers are then
    randomly split into train / validation / test sets, stratified per
    conference (roughly 20% / 10% / 70%).

    Args:
        seed: Optional seed for the random train/val/test split. ``None``
            (the default) draws from NumPy's global random state.

    Returns:
        A DGL heterograph with node types ``paper``, ``author`` and ``field``
        and edge types ``pa``/``ap`` (paper<->author) and ``pf``/``fp``
        (paper<->field). Paper nodes carry:

        - ``feat``: float32 features taken from ``PvsT``;
        - ``label``: int64 class ids;
        - ``train_mask``/``val_mask``/``test_mask``: built by ``_get_mask``
          from the split indices.
    """
    # reference: https://github.com/dmlc/dgl/blob/master/examples/pytorch/han/utils.py
    url = "dataset/ACM.mat"
    data_path = os.path.join(get_download_dir(), "ACM.mat")
    if not os.path.exists(data_path):
        print(f"File {data_path} not found, downloading...")
        download(_get_dgl_url(url), path=data_path)

    data = scipy.io.loadmat(data_path)
    p_vs_l = data["PvsL"]  # paper-field?
    p_vs_a = data["PvsA"]  # paper-author
    p_vs_t = data["PvsT"]  # paper-term, bag of words
    p_vs_c = data["PvsC"]  # paper-conference, labels come from that

    # We assign
    # (1) KDD papers as class 0 (data mining),
    # (2) SIGMOD and VLDB papers as class 1 (database),
    # (3) SIGCOMM and MOBICOMM papers as class 2 (communication)
    conf_ids = [0, 1, 9, 10, 13]
    label_ids = [0, 1, 2, 2, 1]

    # Row indices of the papers that have an entry in one of the selected
    # conference columns. Every per-paper matrix is re-indexed so that paper
    # i below corresponds to original paper p_selected[i].
    p_selected = p_vs_c[:, conf_ids].tocoo().row
    p_vs_l = p_vs_l[p_selected]
    p_vs_a = p_vs_a[p_selected]
    p_vs_t = p_vs_t[p_selected]
    p_vs_c = p_vs_c[p_selected]
    num_papers = len(p_selected)

    hgraph = dgl.heterograph(
        {
            ("paper", "pa", "author"): p_vs_a.nonzero(),
            ("author", "ap", "paper"): p_vs_a.transpose().nonzero(),
            ("paper", "pf", "field"): p_vs_l.nonzero(),
            ("field", "fp", "paper"): p_vs_l.transpose().nonzero(),
        },
        # Pin the paper count so it always matches the per-paper tensors
        # attached below, even if the highest-indexed papers have no
        # author/field edges (DGL would otherwise infer it from edge IDs).
        num_nodes_dict={"paper": num_papers},
    )

    features = torch.tensor(p_vs_t.toarray(), dtype=torch.float32)

    # pc_p[k] is a (re-indexed) paper id and pc_c[k] its conference column.
    pc_p, pc_c = p_vs_c.nonzero()
    labels = np.zeros(num_papers, dtype=np.int64)
    for conf_id, label_id in zip(conf_ids, label_ids):
        labels[pc_p[pc_c == conf_id]] = label_id
    labels = torch.from_numpy(labels)

    # Stratified split: within each conference, its papers receive a random
    # permutation of evenly spaced values in [0, 1], which are thresholded
    # below. Index by paper id (pc_p), exactly like the labels above, not by
    # position in the nonzero list. The two only coincide when the sparse
    # matrix happens to enumerate its nonzeros in row order.
    rng = np.random if seed is None else np.random.RandomState(seed)
    float_mask = np.zeros(num_papers)
    for conf_id in conf_ids:
        pc_c_mask = pc_c == conf_id
        float_mask[pc_p[pc_c_mask]] = rng.permutation(
            np.linspace(0, 1, pc_c_mask.sum())
        )
    train_idx = np.where(float_mask <= 0.2)[0]
    val_idx = np.where((float_mask > 0.2) & (float_mask <= 0.3))[0]
    test_idx = np.where(float_mask > 0.3)[0]

    num_nodes = hgraph.num_nodes("paper")
    train_mask = _get_mask(num_nodes, train_idx)
    val_mask = _get_mask(num_nodes, val_idx)
    test_mask = _get_mask(num_nodes, test_idx)

    hgraph.nodes["paper"].data["feat"] = features
    hgraph.nodes["paper"].data["label"] = labels
    hgraph.nodes["paper"].data["train_mask"] = train_mask
    hgraph.nodes["paper"].data["val_mask"] = val_mask
    hgraph.nodes["paper"].data["test_mask"] = test_mask

    return hgraph