def evaluation()

in baseline_model/data_utils/train_tree_encoder_v2.py [0:0]


def evaluation(opt, iterator, encoder, decoder_l,decoder_r, attention_decoder, criterion, using_gpu, tree_node_gen, teaching_force =True):
    epoch_loss = 0
    encoder.eval()
    decoder_r.eval()
    decoder_l.eval()
    attention_decoder.eval()
    TreeRoot = []
    TrgRoot  = []
    for it, batch in enumerate(iterator):
        enc_batch = batch.src # .transpose(1,0)
        dec_tree_batch  = batch.trg
        enclen = batch.enclen
        enc_max_len  = opt.enc_seq_length #enc_batch.size(1)
        # enc_outputs = torch.zeros((enc_batch.size(0), enc_max_len, encoder.hidden_size), requires_grad=True)
        enc_outputs = torch.zeros((len(enc_batch), enc_max_len, encoder.hidden_size), requires_grad=True)
        if using_gpu:
            enc_outputs = enc_outputs.cuda()
        enc_s = {}
        for j in range(opt.enc_seq_length + 1):
            enc_s[j] = {}

        dec_s = {}
        for i in range(opt.dec_seq_length + 1):
            dec_s[i] = {}
            for j in range(3):
                dec_s[i][j] = {}

        for i in range(1, 3):
            enc_s[0][i] = torch.zeros((opt.batch_size, opt.rnn_size), dtype=torch.float, requires_grad=True)
            if using_gpu:
                enc_s[0][i] = enc_s[0][i].cuda()

        # TODO:change this part
        for i in range(enc_max_len):
            enc_s[i+1][1], enc_s[i+1][2] = encoder(enc_batch, i, enc_s[i][1], enc_s[i][2])
            enc_outputs[:, i, :] = enc_s[i+1][2]
        # tree decode
        queue_tree = {}
        TreeRootGen = {}
        TreeNodeCurrent = {}
        for i in range(1, opt.batch_size+1):
            if tree_node_gen:
                TreeRootGen[i] = Tree(StmtID)
                TreeNodeCurrent[i] = []
                TreeNodeCurrent[i].append(TreeRootGen[i])
            # pdb.set_trace()
            queue_tree[i]  = []
            queue_tree[i].append({"tree" : dec_tree_batch[i-1], "parent": 0, "child_index": 1})

        loss = 0
        cur_index, max_index = 1,1
        dec_batch = {}
        dec_batch_trg = {}

        while (cur_index <= max_index):
            # build dec_batch for cur_index
            max_w_len = -1
            batch_w_list = []
            batch_w_list_trg = []
            for i in range(1, opt.batch_size+1):
                w_list = []
                w_list_trg = []
                if (cur_index <= len(queue_tree[i])):
                    t = queue_tree[i][cur_index - 1]["tree"]
                    # for ic in range (t.num_children):
                    for ic in range (len(t.children)):
                        w_list.append(t.value)
                        w_list_trg.append(t.children[ic].value)
                        if(tree_node_gen):
                            NewTreeNode = Tree(Dummy)
                            NewTreeNode.parent = TreeNodeCurrent[i][cur_index - 1]
                            TreeNodeCurrent[i][cur_index - 1].children.append(NewTreeNode)
                        if(t.children[ic].value != 0):
                            if(tree_node_gen):
                                TreeNodeCurrent[i].append(NewTreeNode)
                            queue_tree[i].append({"tree" : t.children[ic], "parent" : cur_index, "child_index": ic })
                    if len(queue_tree[i]) > max_index:
                        max_index = len(queue_tree[i])
                if len(w_list) > max_w_len:
                    max_w_len = len(w_list)
                batch_w_list.append(w_list)
                batch_w_list_trg.append(w_list_trg)

            dec_batch[cur_index] = torch.zeros((opt.batch_size, 2), dtype=torch.long)
            dec_batch_trg[cur_index] = torch.zeros((opt.batch_size, 2), dtype=torch.long)
            for i in range(opt.batch_size):
                w_list = batch_w_list[i]
                w_list_trg = batch_w_list_trg[i]
                if len(w_list) > 0:
                    for j in range(len(w_list)):
                        dec_batch[cur_index][i][j] = w_list[j]
                        dec_batch_trg[cur_index][i][j] = w_list_trg[j]

            # initialize first decoder unit hidden state (zeros)
            if using_gpu:
                dec_batch[cur_index] = dec_batch[cur_index].cuda()
                dec_batch_trg[cur_index] = dec_batch_trg[cur_index].cuda()

            # initialize using encoding results
            for j in range(1, 3):
                dec_s[cur_index][0][j] = torch.zeros((opt.batch_size, opt.rnn_size), dtype=torch.float, requires_grad=True)
                if using_gpu:
                    dec_s[cur_index][0][j] = dec_s[cur_index][0][j].cuda()

            #dec_s  1: cur_index 2: child index 3. h (1) or s (2)
            if cur_index == 1:
                for i in range(opt.batch_size):
                    dec_s[1][0][1][i, :] = enc_s[enclen[i]][1][i, :]
                    dec_s[1][0][2][i, :] = enc_s[enclen[i]][2][i, :]
            else:
                for i in range(1, opt.batch_size+1):
                    if (cur_index <= len(queue_tree[i])):
                        par_index = queue_tree[i][cur_index - 1]["parent"]
                        child_index = queue_tree[i][cur_index - 1]["child_index"]
                        dec_s[cur_index][0][1][i-1,:] = dec_s[par_index][child_index][1][i-1,:]
                        dec_s[cur_index][0][2][i-1,:] = dec_s[par_index][child_index][2][i-1,:]
            #loss = 0
            #prev_c, prev_h = dec_s[cur_index, 0, 0,:,:], dec_s[cur_index, 0, 1,:,:]
            #pred_matrix = np.ndarray((20, dec_batch[cur_index].size(1)-1), dtype=object)
            parent_h = dec_s[cur_index][0][2]

            # left-right decoder style
            dec_s[cur_index][1][1], dec_s[cur_index][1][2] =decoder_l(dec_batch[cur_index][:,0], dec_s[cur_index][0][1], dec_s[cur_index][0][2], parent_h)
            pred_l = attention_decoder(enc_outputs,dec_s[cur_index][1][2])
            loss += criterion(pred_l, dec_batch_trg[cur_index][:,0])

            dec_s[cur_index][2][1],dec_s[cur_index][2][2] = decoder_r(dec_batch[cur_index][:,1], dec_s[cur_index][0][1], dec_s[cur_index][0][2], parent_h)
            pred_r = attention_decoder(enc_outputs,dec_s[cur_index][2][2])
            loss += criterion(pred_r, dec_batch_trg[cur_index][:,1])

            # pdb.set_trace()
            max_pred_l = torch.max(pred_l,1)[1]
            max_pred_r = torch.max(pred_r,1)[1]
            # pdb.set_trace()
            if(tree_node_gen):
                for i in range(1,opt.batch_size + 1):
                    try:
                        if(cur_index <= len(TreeNodeCurrent[i]) and len(TreeNodeCurrent[i][cur_index - 1 ].children) != 0 ):
                            TreeNodeCurrent[i][cur_index - 1 ].children[0].value = max_pred_l[i-1].item()
                            if len(TreeNodeCurrent[i][cur_index - 1 ].children) > 1:
                                TreeNodeCurrent[i][cur_index - 1 ].children[1].value = max_pred_r[i-1].item()
                    except:
                        pdb.set_trace()
                # pdb.set_trace()
        # for i in range(dec_batch[cur_index].size(1)):
            # pdb.set_trace()
            #     #print(i)
            #     # pdb.set_trace()
            #     dec_s[cur_index][i+1][1], dec_s[cur_index][i+1][2] = decoder(dec_batch[cur_index][:,i], dec_s[cur_index][i][1], dec_s[cur_index][i][2], parent_h)
            #     pred = attention_decoder(enc_outputs, dec_s[cur_index][i+1][2])
            #     loss += criterion(pred, dec_batch[cur_index][:,i+1])

            cur_index = cur_index + 1
        # pdb.set_trace()
        TreeRoot.append(TreeRootGen)
        TrgRoot.append(batch.trg)
        loss = loss / opt.batch_size
        epoch_loss += loss.item()
    return epoch_loss / len(iterator), TrgRoot , TreeRoot