def generate_index()

in pycls/models/relation_graph.py [0:0]


def generate_index(message_type='ba', n=16, sparsity=0.5, p=0.2,
                   directed=False, seed=123):
    """Generate the edge index of a relation graph.

    Args:
        message_type: Graph family to build. One of 'er', 'random', 'ws',
            'ba', 'hypercube', 'grid', 'cycle', 'tree', 'regular', or a
            predefined graph name ('mcwhole', 'mcwholeraw', 'mcvisual',
            'mcvisualraw', 'cat', 'catraw') loaded with ``load_graph``.
        n: Number of nodes. For 'hypercube' the graph is built on
            ``int(log2(n))`` dimensions; for 'grid' it is a
            ``k x (n // k)`` lattice with ``k = round(n * sparsity)``.
        sparsity: Density control; the target degree is ``n * sparsity``.
            For 'random' it is the fraction of the ``n * n`` ordered
            (source, target) pairs, self-pairs included, that become edges.
        p: Rewiring probability, used only by 'ws'.
        directed: Forwarded to ``nx_to_edge`` for generated graphs. Not used
            by 'random', and predefined graphs always use ``directed=True``.
        seed: Random seed for the stochastic generators, including 'random'.

    Returns:
        For 'random', an integer array of shape ``(edge_num, 2)`` holding
        (source, target) pairs; otherwise whatever ``nx_to_edge`` returns.

    Raises:
        NotImplementedError: If ``message_type`` is not recognised.
    """
    degree = n * sparsity
    # ``degree`` is a float whenever ``sparsity`` is, but the networkx
    # generators below need integer node / neighbour counts (they call
    # ``range`` on them), so derive an integer degree once here.
    int_degree = int(round(degree))
    known_names = ['mcwhole', 'mcwholeraw', 'mcvisual', 'mcvisualraw',
                   'cat', 'catraw']

    if message_type == 'random':
        # Sample distinct cells of the flattened n x n adjacency matrix and
        # decode each flat id into a (row, col) pair. A local seeded
        # generator makes the result honour ``seed`` like the other types.
        edge_num = int(n * n * sparsity)
        rng = np.random.RandomState(seed)
        edge_id = rng.choice(n * n, edge_num, replace=False)
        rows, cols = np.divmod(edge_id, n)
        return np.stack([rows, cols], axis=1).astype(int)

    if message_type in known_names:
        # Predefined graphs are always converted as directed.
        graph = load_graph(message_type)
        return nx_to_edge(graph, directed=True, seed=seed)

    if message_type == 'er':
        graph = nx.gnm_random_graph(n=n, m=int(n * degree // 2), seed=seed)
    elif message_type == 'ws':
        # connected_ws_graph is a project helper; the unrounded degree is
        # passed through as before (TODO confirm it accepts fractional k).
        graph = connected_ws_graph(n=n, k=degree, p=p, seed=seed)
    elif message_type == 'ba':
        graph = nx.barabasi_albert_graph(n=n, m=int_degree // 2, seed=seed)
    elif message_type == 'hypercube':
        # NOTE(review): if n is not a power of two this silently builds a
        # smaller graph with 2 ** floor(log2(n)) nodes.
        graph = nx.hypercube_graph(n=int(np.log2(n)))
    elif message_type == 'grid':
        # ``int_degree`` rows by ``n // int_degree`` columns.
        graph = nx.grid_2d_graph(m=int_degree, n=n // int_degree)
    elif message_type == 'cycle':
        graph = nx.cycle_graph(n=n)
    elif message_type == 'tree':
        graph = nx.random_tree(n=n, seed=seed)
    elif message_type == 'regular':
        graph = nx.connected_watts_strogatz_graph(
            n=n, k=int_degree, p=0, seed=seed)
    else:
        raise NotImplementedError(
            'Unknown message_type: {!r}'.format(message_type))
    return nx_to_edge(graph, directed=directed, seed=seed)