in baseline_model/data_utils/train_tree_encoder.py [0:0]
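# Imports used by this function (a sketch of what the module needs; os, pickle,
# torch and dgl are referenced directly below). `init_tqdm` and
# `get_novel_positional_encoding` are assumed to be defined elsewhere in this module.
import os
import pickle

import dgl
import torch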
def processing_data(cache_dir, iterators):
    print("preprocessing data...")
    for iterator in iterators:
        for i, batch in init_tqdm(enumerate(iterator), 'preprocess'):
            trg = batch.trg
            id_elem = batch.id
            path = os.path.join(cache_dir, str(id_elem[0]))
            if not os.path.exists(path):
                os.makedirs(path)
            else:
                continue
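            # Per-example state: a queue of tree nodes still to expand and a DGL graph
            # (plus node values, depths and positional encodings) built as the target
            # tree is traversed.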
            batch_size = len(trg)
            queue_tree = {}
            graphs = []
            graphs_data = []
            graphs_data_depth = []
            graphs_data_encoding = []
            total_tree_num = [1 for i in range(0, batch_size)]
            for i in range(1, batch_size + 1):
                queue_tree[i] = []
                queue_tree[i].append({"tree": trg[i-1], "parent": 0, "child_index": 1, "tree_path": [], "depth": 0, "child_num": len(trg[i-1].children), "encoding": []})
                total_tree_num[i-1] += len(trg[i-1].children)
                g = dgl.DGLGraph()
                graphs.append(g)
                graphs_data.append([])
                graphs_data_depth.append([])
                graphs_data_encoding.append([])
            cur_index, max_index = 1, 1
            ic = 0
            dict_info = {}
            last_append = [None] * batch_size
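            # Expand the queued nodes one child slot at a time: cur_index walks the
            # queue, ic is the child index currently being expanded, and flag stays 0
            # while the current node still has unvisited children. Each (cur_index, ic)
            # step is pickled to its own cache file below.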
            while cur_index <= max_index:
                max_w_len = -1
                max_w_len_path = -1
                batch_w_list_trg = []
                batch_w_list = []
                flag = 1
                for i in range(1, batch_size + 1):
                    w_list_trg = []
                    if cur_index <= len(queue_tree[i]):
                        t_node = queue_tree[i][cur_index - 1]
                        t_encode = t_node["encoding"]
                        t_depth = t_node["depth"]
                        t = t_node["tree"]
                        if ic == 0:
                            queue_tree[i][cur_index - 1]["tree_path"].append(t.value)
                        t_path = queue_tree[i][cur_index - 1]["tree_path"].copy()
                        if ic == 0 and cur_index == 1:
                            graphs[i-1].add_nodes(1)
                            graphs[i-1].add_edges(t_node["parent"], cur_index - 1)
                            graphs_data[i-1].append(t.value)
                            graphs_data_depth[i-1].append(t_depth)
                            graphs_data_encoding[i-1].append(t_encode)
                        elif ic <= t_node['child_num'] - 1:
                            t_node_child = last_append[i-1]
                            graphs[i-1].add_nodes(1)
                            graphs[i-1].add_edges(t_node_child["parent"], len(queue_tree[i]) - 1)
                            graphs_data[i-1].append(t_node_child["tree"].value)
                            graphs_data_depth[i-1].append(t_node_child["depth"])
                            graphs_data_encoding[i-1].append(t_node_child["encoding"])
                        # if it is not expanding all the children, add children into the queue
                        if ic <= t_node['child_num'] - 1:
                            w_list_trg.append(t.children[ic].value)
                            encoding = get_novel_positional_encoding(t.children[ic], ic, t_node)
                            if t.children[ic].value != 0:
                                last_append[i-1] = {"tree": t.children[ic], "parent": cur_index - 1, "child_index": ic, "tree_path": t_path,
                                                    "depth": t_depth + 1, "child_num": len(t.children[ic].children), "encoding": encoding}
                                if len(t.children[ic].children) > 0:
                                    queue_tree[i].append({"tree": t.children[ic], "parent": cur_index - 1, "child_index": ic, "tree_path": t_path,
                                                          "depth": t_depth + 1, "child_num": len(t.children[ic].children), "encoding": encoding})
                            if ic + 1 < t_node['child_num']:
                                flag = 0
                        if len(queue_tree[i]) > max_index:
                            max_index = len(queue_tree[i])
                    if len(t_path) > max_w_len_path:
                        max_w_len_path = len(t_path)
                    if len(graphs_data[i-1]) > max_w_len:
                        max_w_len = len(graphs_data[i-1])
                    batch_w_list_trg.append(w_list_trg)
                    batch_w_list.append(t_path)
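                # Only the first example in the batch is serialized; one pickle file
                # per (cur_index, ic) step, named "<cur_index>_<ic>".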
                dict_info = {
                    'batch_w_list': batch_w_list[0],
                    'batch_w_list_trg': batch_w_list_trg[0],
                    'graphs': graphs[0],
                    "graph_data": torch.tensor(graphs_data[0]),
                    "graph_depth": torch.tensor(graphs_data_depth[0]),
                    "graph_data_encoding": graphs_data_encoding[0]
                }
                if batch_w_list_trg[0] == [] and ic == 0:
                    print(ic)
                with open(os.path.join(path, str(cur_index) + '_' + str(ic)), 'wb') as f:
                    if dict_info == {}:
                        print(dict_info)
                    pickle.dump(dict_info, f)
                cur_index = cur_index + flag
                ic = 0 if flag == 1 else ic + 1
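# Hypothetical usage sketch (names below are assumptions, not part of this module):
# `train_iter` would be a torchtext-style iterator whose batches expose `.trg`
# (a list of tree objects with `.value` / `.children`) and `.id`; the per-step
# dicts written above are later reloaded from the cache directory with pickle.load.
#
# processing_data('cache/train', [train_iter])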