in hugegraph-ml/src/hugegraph_ml/utils/dgl2hugegraph_utils.py [0:0]
def load_training_data_gatne(f_name="dataset/amazon/train.txt"):
    """Load a GATNE edge-list file into a single-node-type DGL heterograph.

    Reference implementation:
    https://github.com/dmlc/dgl/blob/master/examples/pytorch/GATNE-T/src/utils.py
    https://github.com/dmlc/dgl/blob/master/examples/pytorch/GATNE-T/src/main.py

    Each non-blank line of ``f_name`` must contain at least three
    whitespace-separated tokens, ``<edge_type> <src_node> <dst_node>``.
    Any further tokens on the line are ignored.

    Args:
        f_name: Path of the edge-list file. Defaults to the Amazon training
            split, ``"dataset/amazon/train.txt"``.

    Returns:
        A ``dgl.heterograph`` with one node type ``"_N"`` and one relation
        ``("_N", edge_type, "_N")`` per distinct edge type in the file.
        Every input edge is stored in both directions.

    Raises:
        ValueError: If a non-blank line has fewer than three tokens, or the
            file contains no edges at all.
    """
    print("We are loading data from:", f_name)
    edge_data_by_type = {}
    with open(f_name, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            # split() strips the newline (including CRLF) and does not assume
            # the last line ends with '\n', so no id character is truncated.
            words = line.split()
            if not words:
                continue  # tolerate blank lines
            if len(words) < 3:
                raise ValueError(
                    f"{f_name}:{line_no}: expected '<edge_type> <src> <dst>', got {line!r}"
                )
            edge_type, x, y = words[0], words[1], words[2]
            edge_data_by_type.setdefault(edge_type, []).append((x, y))
    if not edge_data_by_type:
        raise ValueError(f"No edges found in {f_name}")

    # Assign contiguous integer ids in first-appearance order. This keeps the
    # node-id mapping reproducible across runs; a set-based ordering of string
    # ids would vary with PYTHONHASHSEED.
    vocab = {}
    for edges in edge_data_by_type.values():
        for x, y in edges:
            vocab.setdefault(x, len(vocab))
            vocab.setdefault(y, len(vocab))

    node_type = "_N"  # '_N' can be replaced by an arbitrary name
    num_nodes_dict = {node_type: len(vocab)}
    data_dict = {}
    for edge_type, edges in edge_data_by_type.items():
        src, dst = [], []
        for x, y in edges:
            # Add each edge in both directions so every relation is symmetric.
            src.extend((vocab[x], vocab[y]))
            dst.extend((vocab[y], vocab[x]))
        data_dict[(node_type, edge_type, node_type)] = (src, dst)
    return dgl.heterograph(data_dict, num_nodes_dict)