def load_training_data_gatne()

in hugegraph-ml/src/hugegraph_ml/utils/dgl2hugegraph_utils.py [0:0]


def load_training_data_gatne(f_name="dataset/amazon/train.txt"):
    """Load a GATNE-style edge list and build a DGL heterograph from it.

    Each non-blank line of the input file is expected to contain at least
    three whitespace-separated fields: ``<edge_type> <node_a> <node_b>``.
    Any further fields on a line are ignored.

    All nodes share the single node type ``"_N"``. Every edge read from the
    file is inserted in both directions (a -> b and b -> a) under its edge
    type.

    Node identifiers from the file are mapped to consecutive integer ids in
    order of first appearance, so the numbering is reproducible from run to
    run for the same input file.

    Args:
        f_name: Path of the edge-list file to read. Defaults to the path
            that was previously hard-coded in this function.

    Returns:
        The object returned by ``dgl.heterograph(data_dict, num_nodes_dict)``.

    Raises:
        OSError: If ``f_name`` cannot be opened.
        IndexError: If a non-blank line has fewer than three fields.
    """
    # reference: https://github.com/dmlc/dgl/blob/master/examples/pytorch/GATNE-T/src/utils.py
    # reference: https://github.com/dmlc/dgl/blob/master/examples/pytorch/GATNE-T/src/main.py
    print("We are loading data from:", f_name)

    edge_data_by_type = {}
    with open(f_name, "r", encoding="utf-8") as f:
        for line in f:
            # split() copes with a missing final newline and stray
            # whitespace; slicing off the last character would truncate the
            # last node id when the file does not end with '\n'.
            words = line.split()
            if not words:
                continue  # tolerate blank lines (e.g. a trailing empty line)
            edge_type, x, y = words[0], words[1], words[2]
            edge_data_by_type.setdefault(edge_type, []).append((x, y))

    node_type = "_N"  # '_N' can be replaced by an arbitrary name
    # Raw node id -> integer index, assigned on first sight. Iterating a set
    # of strings instead would make the numbering depend on hash
    # randomization and therefore differ between processes.
    vocab = {}
    data_dict = {}
    for edge_type, edges in edge_data_by_type.items():
        src = []
        dst = []
        for x, y in edges:
            u = vocab.setdefault(x, len(vocab))
            v = vocab.setdefault(y, len(vocab))
            # Store both directions of every edge.
            src.extend((u, v))
            dst.extend((v, u))
        data_dict[(node_type, edge_type, node_type)] = (src, dst)

    num_nodes_dict = {node_type: len(vocab)}
    graph = dgl.heterograph(data_dict, num_nodes_dict)
    return graph