def get_communities()

in hugegraph-ml/src/hugegraph_ml/models/pgnn.py [0:0]


def get_communities(remove_feature, graph):
    """Build link-prediction inputs from a community-structured graph.

    Each edge is rewired with probability 0.01 to a randomly chosen node,
    then self-loops are removed. Node pairs whose IDs fall into the same
    block of 20 consecutive IDs are labelled as positive edges.

    Note:
        ``graph`` is modified in place (edges are removed and added).

    Args:
        remove_feature: If truthy, every node gets the same constant
            1-dimensional feature. Otherwise each node gets one row of a
            randomly column-permuted identity matrix.
        graph: A networkx-style graph. Node IDs are expected to be the
            consecutive integers ``0 .. n-1`` because they are used
            directly as array indices.

    Returns:
        A dict with the keys:

        * ``"edge_index"``: ``torch.LongTensor`` of shape ``(2, 2 * E)``
          holding every remaining edge in both directions.
        * ``"feature"``: ``torch.Tensor`` of shape ``(n, 1)`` when
          ``remove_feature`` is truthy, otherwise a NumPy array of shape
          ``(n, n)``.
        * ``"positive_edges"``: NumPy array of shape ``(2, P)`` listing the
          ``(u, v)`` pairs with ``u > v`` that share a community.
        * ``"num_nodes"``: the number of nodes ``n``.
    """
    community_size = 20

    # Randomly rewire 1% edges.
    node_list = list(graph.nodes)
    # Iterate over a snapshot: the loop body removes and adds edges, and
    # mutating a graph while walking its live edge view is not safe.
    for u, v in list(graph.edges()):
        if random.random() < 0.01:
            x = random.choice(node_list)
            if graph.has_edge(u, x):
                continue
            graph.remove_edge(u, v)
            graph.add_edge(u, x)

    # Rewiring may pick x == u, so drop any self-loops it introduced.
    graph.remove_edges_from(nx.selfloop_edges(graph))

    # reshape(-1, 2) keeps the array two-dimensional even when the graph has
    # no edges, so the column reversal below does not raise IndexError.
    edge_index = np.array(list(graph.edges)).reshape(-1, 2)
    # Add (i, j) for an edge (j, i)
    edge_index = np.concatenate((edge_index, edge_index[:, ::-1]), axis=0)
    edge_index = torch.from_numpy(edge_index).long().permute(1, 0)

    n = graph.number_of_nodes()
    label = np.zeros((n, n), dtype=int)
    for u in node_list:
        # the node IDs are simply consecutive integers from 0
        for v in range(u):
            if u // community_size == v // community_size:
                label[u, v] = 1

    # NOTE: the two branches intentionally keep their original types
    # (torch tensor vs. NumPy array) so existing callers are unaffected.
    if remove_feature:
        feature = torch.ones((n, 1))
    else:
        rand_order = np.random.permutation(n)
        feature = np.identity(n)[:, rand_order]

    data = {
        "edge_index": edge_index,
        "feature": feature,
        "positive_edges": np.stack(np.nonzero(label)),
        "num_nodes": feature.shape[0],
    }

    return data