def processing_data()

in baseline_model/data_utils/train_tree_encoder.py [0:0]


def processing_data(cache_dir, iterators):
    """Pre-compute per-step tree-expansion snapshots and pickle them to disk.

    For every batch of every iterator a directory
    ``<cache_dir>/<str(batch.id[0])>`` is created.  If that directory already
    exists the batch is skipped entirely.

    Each target tree in ``batch.trg`` is then walked with a first-in,
    first-out queue that only holds nodes having at least one child.  One
    step of the walk looks at child number ``ic`` of the queue entry at
    1-based position ``cur_index``.  After every step a dict describing the
    FIRST tree of the batch only is pickled to the file
    ``<cur_index>_<ic>`` inside the batch directory, with the keys
    ``batch_w_list``, ``batch_w_list_trg``, ``graphs``, ``graph_data``,
    ``graph_depth`` and ``graph_data_encoding``.

    Args:
        cache_dir: Directory under which one sub-directory per batch is made.
        iterators: Iterable of batch iterators.  Each batch must expose
            ``trg`` (a sized, indexable collection of tree nodes that have
            ``value`` and ``children`` attributes) and ``id`` (indexable).

    Returns:
        None.  The results are the files written below ``cache_dir``.
    """
    print("preprocessing data...")
    for iterator in iterators:
        # The enumerate index is not needed; the call is kept so init_tqdm
        # receives exactly the same iterable as before.
        for _, batch in init_tqdm(enumerate(iterator), 'preprocess'):
            trg = batch.trg
            id_elem = batch.id

            path = os.path.join(cache_dir, str(id_elem[0]))
            try:
                # EAFP: create-or-skip without a check-then-create race.
                os.makedirs(path)
            except FileExistsError:
                # Already present: this batch is treated as done.
                continue

            batch_size = len(trg)
            queues = []                 # per tree: FIFO of nodes to expand
            graphs = []                 # per tree: incrementally built graph
            graphs_data = []            # per tree: value of every graph node
            graphs_data_depth = []      # per tree: depth of every graph node
            graphs_data_encoding = []   # per tree: encoding of every node

            for root in trg:
                queues.append([{
                    "tree": root,
                    "parent": 0,
                    "child_index": 1,
                    "tree_path": [],
                    "depth": 0,
                    "child_num": len(root.children),
                    "encoding": [],
                }])
                graphs.append(dgl.DGLGraph())
                graphs_data.append([])
                graphs_data_depth.append([])
                graphs_data_encoding.append([])

            cur_index, max_index = 1, 1   # 1-based queue position / bound
            ic = 0                        # index of the child being expanded
            # Most recent non-zero child seen per tree; it becomes a graph
            # node on the following step.
            last_append = [None] * batch_size

            while cur_index <= max_index:
                batch_w_list_trg = []
                batch_w_list = []
                flag = 1   # stays 1 when no tree has children left at cur_index

                for b in range(batch_size):
                    queue = queues[b]
                    graph = graphs[b]
                    w_list_trg = []
                    # Reset per tree so an exhausted tree does not report the
                    # path left over from another tree or an earlier step.
                    t_path = []

                    if cur_index <= len(queue):
                        t_node = queue[cur_index - 1]
                        t_encode = t_node["encoding"]
                        t_depth = t_node["depth"]
                        t = t_node["tree"]
                        if ic == 0:
                            # First visit of this node: extend its own path.
                            t_node["tree_path"].append(t.value)
                        t_path = t_node["tree_path"].copy()

                        has_child = ic <= t_node['child_num'] - 1

                        if ic == 0 and cur_index == 1:
                            # Very first step: add the root (self-edge 0 -> 0).
                            graph.add_nodes(1)
                            graph.add_edges(t_node["parent"], cur_index - 1)
                            graphs_data[b].append(t.value)
                            graphs_data_depth[b].append(t_depth)
                            graphs_data_encoding[b].append(t_encode)
                        elif has_child:
                            # Materialise the child recorded on the previous
                            # step.
                            # NOTE(review): `pending` is still None if no
                            # non-zero child has been recorded yet; that would
                            # raise TypeError here - confirm inputs rule it out.
                            pending = last_append[b]
                            graph.add_nodes(1)
                            # NOTE(review): both endpoints are queue positions,
                            # not the id of the node just added - confirm this
                            # is intended.
                            graph.add_edges(pending["parent"], len(queue) - 1)
                            graphs_data[b].append(pending["tree"].value)
                            graphs_data_depth[b].append(pending["depth"])
                            graphs_data_encoding[b].append(pending["encoding"])

                        # If not all children are expanded yet, record the
                        # next one and enqueue it when it has children itself.
                        if has_child:
                            child = t.children[ic]
                            w_list_trg.append(child.value)
                            encoding = get_novel_positional_encoding(child, ic, t_node)
                            if child.value != 0:
                                child_entry = {
                                    "tree": child,
                                    "parent": cur_index - 1,
                                    "child_index": ic,
                                    "tree_path": t_path,
                                    "depth": t_depth + 1,
                                    "child_num": len(child.children),
                                    "encoding": encoding,
                                }
                                last_append[b] = child_entry
                                if len(child.children) > 0:
                                    # Separate dict, same shared values, as in
                                    # the original two literals.
                                    queue.append(dict(child_entry))

                        if ic + 1 < t_node['child_num']:
                            # More children remain: stay on this queue entry.
                            flag = 0

                        max_index = max(max_index, len(queue))

                    batch_w_list_trg.append(w_list_trg)
                    batch_w_list.append(t_path)

                # Only the first tree of the batch is persisted.
                dict_info = {
                    'batch_w_list': batch_w_list[0],
                    'batch_w_list_trg': batch_w_list_trg[0],
                    'graphs': graphs[0],
                    "graph_data": torch.tensor(graphs_data[0]),
                    "graph_depth": torch.tensor(graphs_data_depth[0]),
                    "graph_data_encoding": graphs_data_encoding[0],
                }

                if batch_w_list_trg[0] == [] and ic == 0:
                    # Diagnostic output: a step produced no target for the
                    # first tree.
                    print(ic)

                with open(os.path.join(path, str(cur_index) + '_' + str(ic)), 'wb') as f:
                    pickle.dump(dict_info, f)

                # Advance to the next queue entry, or to the next child of
                # the current one.
                if flag == 1:
                    cur_index += 1
                    ic = 0
                else:
                    ic += 1