def test_tree()

in baseline_model/data_utils/train_tree_encoder.py [0:0]


def test_tree(args, model, iterator, trg_pad_idx, device, smoothing, criterion, clip):
    """Evaluate the tree decoder on ``iterator`` and return per-word loss and accuracy.

    For every batch the source assembly graphs are encoded once. The target
    trees are then expanded node by node. Each pass of the ``while`` loop
    predicts child ``ic`` of the ``cur_index``-th queued node of every sample.
    The partial tree built so far is kept as one DGL graph per sample and fed
    through ``model.gnn`` / ``model.decoder``. Children of the root
    (``cur_index == 1``) are labelled with their ground-truth value. Deeper
    children are labelled with the model's argmax prediction, stored on the
    node's ``.predict`` attribute.

    Args:
        args: Config namespace; reads ``graph_aug``, ``output_dim``,
            ``hid_dim`` and ``depth_dim``.
        model: The model, optionally wrapped in a container that exposes it
            as ``.module`` (e.g. ``nn.DataParallel``). Must provide
            ``gnn_asm``, ``make_src_mask``, ``encoder``, ``gnn`` and ``decoder``.
        iterator: Yields mappings with keys ``'trg'`` (list of tree roots
            exposing ``.value``/``.children``), ``'graphs_asm'`` (DGL graphs)
            and ``'src_len'``.
        trg_pad_idx: Label used as the target for samples that have nothing
            to predict at a given step.
        device: Device the model lives on; per-step tensors are moved there.
        smoothing: Forwarded to ``cal_performance``.
        criterion: Unused; kept for signature compatibility.
        clip: Unused; kept for signature compatibility.

    Returns:
        ``(loss_per_word, accuracy)`` accumulated over the whole iterator.

    Raises:
        ZeroDivisionError: if ``cal_performance`` reports no words at all
            (e.g. an empty iterator).

    Side effects:
        Calls ``model.eval()`` and sets ``.predict`` on nodes of the target trees.
    """
    n_word_total, n_word_correct = 0, 0
    epoch_loss = 0
    batch_graph_tmp = None
    model.eval()
    # Unwrap nn.DataParallel-style containers, but also accept a bare model.
    model = getattr(model, "module", model)
    with torch.set_grad_enabled(False):
        for batch in iterator:
            trg = batch['trg']
            graphs_asm = batch['graphs_asm']
            src_len = batch['src_len']

            # Encode the source graphs once per batch.
            batch_asm = dgl.batch(graphs_asm).to(device)
            enc_src = model.gnn_asm(batch_asm)
            src_mask = model.make_src_mask(enc_src.max(2)[0])
            if args.graph_aug:
                batch_graph_tmp = preprocessing_batch_tmp(src_len, graphs_asm, device).to(device)

            enc_src = model.encoder(enc_src, src_mask, batch_graph_tmp)
            batch_size = len(trg)

            # queue_tree[i] (1-based sample index) holds the root plus every
            # child that itself has children, in expansion order.
            # graphs[i-1] is the partial-tree graph of sample i.
            # graphs_data / graphs_data_depth hold each graph node's label and depth.
            queue_tree = {}
            graphs = []
            graphs_data = []
            graphs_data_depth = []
            for i in range(1, batch_size + 1):
                root = trg[i-1]
                queue_tree[i] = [{"tree": root, "parent": 0, "child_index": 1, "tree_path": [], "depth": 0,
                                  "child_num": len(root.children), "encoding": [], "predict": root.value}]
                root.predict = root.value
                graphs.append(dgl.DGLGraph())
                graphs_data.append([])
                graphs_data_depth.append([])

            # cur_index: 1-based queue position of the node being expanded.
            # ic: index of the child of that node being predicted this step.
            cur_index, max_index = 1, 1
            loss, ic = 0, 0
            last_append = [None] * batch_size

            while cur_index <= max_index:
                max_w_len_path = -1
                batch_w_list_trg = []
                batch_w_list = []
                flag = 1
                t = [None] * batch_size
                batch_graph_len_list = [0] * batch_size
                # Shared empty placeholder; slots that are used get replaced, never mutated.
                graphs_tmp = [dgl.DGLGraph()] * batch_size

                for i in range(1, batch_size + 1):
                    w_list_trg = []
                    # Samples whose queue is already exhausted contribute an empty path.
                    # Without this they would reuse a stale t_path from another
                    # sample and could inflate max_w_len_path for the whole batch.
                    t_path = []
                    if cur_index <= len(queue_tree[i]):
                        t_node = queue_tree[i][cur_index - 1]
                        t_depth = t_node["depth"]
                        t[i-1] = t_node["tree"]
                        if ic == 0:
                            t_node["tree_path"].append(t[i-1].predict)
                        t_path = t_node["tree_path"].copy()

                        if ic == 0 and cur_index == 1:
                            # First step: put the root into the partial-tree graph.
                            graphs[i-1].add_nodes(1)
                            graphs[i-1].add_edges(t_node["parent"], cur_index - 1)
                            graphs_data[i-1].append(t[i-1].predict)
                            graphs_data_depth[i-1].append(t_depth)
                        elif ic <= t_node['child_num'] - 1:
                            # Add the child recorded in last_append during an earlier
                            # step (its .predict was filled in at the end of that step).
                            # NOTE(review): edge endpoints are queue positions, not graph
                            # node ids; kept unchanged — confirm this is intended.
                            t_node_child = last_append[i-1]
                            graphs[i-1].add_nodes(1)
                            graphs[i-1].add_edges(t_node_child["parent"], len(queue_tree[i]) - 1)
                            graphs_data[i-1].append(t_node_child["tree"].predict)
                            graphs_data_depth[i-1].append(t_node_child["depth"])

                        # While this node still has children to expand, record child ic
                        # and queue it if it has children of its own.
                        if ic <= t_node['child_num'] - 1:
                            child = t[i-1].children[ic]
                            w_list_trg.append(child.value)
                            encoding = get_novel_positional_encoding(child, ic, t_node)
                            if child.value != 0:
                                entry = {"tree": child, "parent": cur_index - 1, "child_index": ic, "tree_path": t_path,
                                         "depth": t_depth + 1, "child_num": len(child.children), "encoding": encoding}
                                last_append[i-1] = entry
                                if len(child.children) > 0:
                                    # Shallow copy: both dicts share the same tree_path list.
                                    queue_tree[i].append(dict(entry))
                            batch_graph_len_list[i-1] = len(graphs_data[i-1])
                            graphs_tmp[i-1] = graphs[i-1]
                        else:
                            batch_graph_len_list[i-1] = 0
                            graphs_tmp[i-1] = dgl.DGLGraph()

                        # Stay on the current queue position while any sample still has children left.
                        if ic + 1 < t_node['child_num']:
                            flag = 0

                        max_index = max(max_index, len(queue_tree[i]))
                    max_w_len_path = max(max_w_len_path, len(t_path))
                    batch_w_list_trg.append(w_list_trg)
                    batch_w_list.append(t_path)

                trg_l = [trg_pad_idx] * batch_size
                w_list_len = []

                in_ = torch.zeros((batch_size, args.output_dim, max_w_len_path), dtype=torch.long)
                for i in range(batch_size):
                    n_nodes = graphs[i].number_of_nodes()
                    annotation = torch.zeros([n_nodes, args.hid_dim - args.depth_dim], dtype=torch.long)
                    depth_annotation = torch.zeros([n_nodes, args.depth_dim], dtype=torch.long)
                    for idx, (label, depth) in enumerate(zip(graphs_data[i], graphs_data_depth[i])):
                        annotation[idx][label] = 1
                        depth_annotation[idx][depth] = 1

                    if batch_w_list_trg[i]:
                        # Mark row cur_index-1 (queue position of the expanded node) in the last depth slot.
                        depth_annotation[cur_index-1][-1] = 1
                    graphs[i].ndata['annotation'] = torch.cat([annotation, depth_annotation], dim=1)

                    w_list_trg = batch_w_list_trg[i]
                    t_path = batch_w_list[i]
                    w_list_len.append(len(t_path) - 1)
                    if w_list_trg:
                        trg_l[i] = w_list_trg[0]
                        for j, label in enumerate(t_path):
                            in_[i][label][j] = 1
                        # Encode the child index ic by counting back from the end of the output_dim axis.
                        in_[i][-ic-1][len(t_path) - 1] = 1

                in_ = in_.float().permute(0, 2, 1).to(device)
                if args.graph_aug:
                    batch_graph_tmp = preprocessing_batch_tmp(batch_graph_len_list, graphs_tmp, device).to(device)
                batch_graph = dgl.batch(graphs_tmp).to(device)
                trg_in = model.gnn(batch_graph)
                # batch_graph_tmp is only built (and only meaningful) when graph augmentation is on.
                if args.graph_aug:
                    assert batch_graph_tmp.num_nodes() == trg_in.view(-1, args.hid_dim).shape[0], 'not match ast graph'

                output_l = model.decoder(trg_in, in_, enc_src, src_mask, batch_graph=batch_graph_tmp)

                # Keep the decoder output at the last path position of each sample.
                output = torch.cat(
                    [output_l[p][last].view(1, -1) for p, last in enumerate(w_list_len)], dim=0
                ).view(batch_size, -1)
                trg_ = torch.tensor(trg_l).to(device)
                output_predict_list = output.argmax(1).tolist()
                for p, elem in enumerate(output_predict_list):
                    if t[p] is not None and len(t[p].children) > ic:
                        # Children of the root (1st queued node) get the ground-truth label,
                        # stored as a plain int; deeper nodes get the model's own prediction.
                        t[p].children[ic].predict = trg_l[p] if cur_index < 2 else elem
                loss_itr, n_correct, n_word = cal_performance(
                    output, trg_, trg_pad_idx, smoothing=smoothing)

                loss += loss_itr
                n_word_total += n_word
                n_word_correct += n_correct
                cur_index += flag
                ic = 0 if flag == 1 else ic + 1

            epoch_loss += loss.item()
        loss_per_word = epoch_loss / n_word_total
        accuracy = n_word_correct / n_word_total
    return loss_per_word, accuracy