def train_eval_tree()

in baseline_model/data_utils/train_tree_encoder.py


import os

import dgl
import numpy as np
import torch
import torch.distributed as dist

# preprocessing_batch_tmp and cal_performance are repo-local helpers assumed
# to be defined or imported elsewhere in this module.


def train_eval_tree(args, model, iterator, optimizer, device,
                    criterion, dec_seq_length, train_flag=True):
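    """Run one epoch of training (train_flag=True) or evaluation over iterator.

    Each batch's assembly graphs are encoded by a GNN plus transformer encoder;
    the target tree is then expanded step by step from per-sample cached data,
    with loss and word counts accumulated via cal_performance. Note that the
    criterion and dec_seq_length arguments are accepted but unused in this body.
    Returns (loss per word, word-level accuracy) for the epoch.
    """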

    if train_flag:
        mode = 'train'
        model.train()
    else:
        mode = 'valid'
        model.eval()
    n_word_total, n_word_correct = 0, 0
    epoch_loss = 0

    sample_len = args.sample_len
    batch_graph_tmp = None
    batch_size = args.bsz 
    if args.dist_gpu:
        # DistributedDataParallel wraps the network; unwrap to reach submodules.
        model = model.module
    with torch.set_grad_enabled(train_flag):
        for batch in iterator:
            dict_info = batch['dict_info']
            batch_size = len(batch['trg'])
            id_elem = batch['id']
            graphs_asm = batch['graphs_asm']
            src_len = batch['src_len']

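            # Batch the per-sample assembly graphs, run the assembly GNN, and
            # build the source mask from the max-pooled node features before
            # running the transformer encoder.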
            batch_asm = dgl.batch(graphs_asm).to(device)
            enc_src = model.gnn_asm(batch_asm)
            src_mask = model.make_src_mask(enc_src.max(2)[0])            
            if args.graph_aug:
                batch_graph_tmp = preprocessing_batch_tmp(src_len, graphs_asm, device).to(device)

            enc_src = model.encoder(enc_src, src_mask, batch_graph_tmp)
            cur_index, max_index = 1, 1
            loss = 0
            ic = 0
            cur_index_batch = [1] * batch_size

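            # Count cached tree nodes per sample (keys containing '_0' mark a
            # node's first step). For trees longer than sample_len, draw a random
            # window start, biased toward the beginning, so long targets are
            # subsampled rather than always truncated at the front.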
            batch_nodes_num = [None] * batch_size
            for aa in range(batch_size):
                batch_nodes_num[aa] = sum(1 for k in dict_info[aa] if '_0' in k)
                if batch_nodes_num[aa] > sample_len:
                    rand_int = np.random.randint(-sample_len*2+1, batch_nodes_num[aa])
                    if rand_int < 1:
                        cur_index_batch[aa] = 1
                    elif rand_int > batch_nodes_num[aa] - sample_len:
                        cur_index_batch[aa] = batch_nodes_num[aa] - sample_len
                    else:
                        cur_index_batch[aa] = rand_int

            max_index = max(batch_nodes_num)
            graphs = [None] * batch_size
            graphs_data = [None] * batch_size
            graphs_data_depth = [None] * batch_size
            graphs_data_encoding = [None] * batch_size

            if max_index > sample_len:
                max_index = sample_len

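            # Decode the tree step by step: cur_index walks the nodes, ic counts
            # intra-node steps, and flag advances the walk once no sample in the
            # batch has a further step cached for the current node.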
            while cur_index <= max_index:
                flag = 1   # drops to 0 if any sample still has a next step at this node
                max_w_len_path = -1
                batch_w_list_trg = [None] * batch_size
                batch_w_list = [None] * batch_size
                batch_len_trg = [0] * batch_size
                batch_graph_len_list = [0] * batch_size

                for aa in range(batch_size):
                    path = os.path.join(args.cache_path, str(id_elem[aa]), str(cur_index_batch[aa]) + '_' + str(ic))
                    path_next = os.path.join(args.cache_path, str(id_elem[aa]), str(cur_index_batch[aa]) + '_' + str(ic + 1))
                    if path in dict_info[aa]:

                        batch_w_list[aa]      = dict_info[aa][path]['batch_w_list']
                        batch_len_trg[aa]     = len(dict_info[aa][path]['batch_w_list'])
                        batch_w_list_trg[aa]  = dict_info[aa][path]['batch_w_list_trg']
                        graphs[aa]            = dict_info[aa][path]['graphs'].to(device)
                        graphs_data[aa]       = dict_info[aa][path]['graph_data'].to(device, non_blocking=True)
                        batch_graph_len_list[aa]  = len(graphs_data[aa])
                        graphs_data_depth[aa] = dict_info[aa][path]['graph_depth'].to(device, non_blocking=True)
                    else:
                        graphs_data[aa] = None
                        graphs[aa] = dgl.DGLGraph().to(device)  # empty placeholder graph
                        batch_graph_len_list[aa] = 0
    
                    if path_next in dict_info[aa]:
                        flag = 0

                max_w_len_path = max(batch_len_trg)
                
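                # in_ one-hot encodes each sample's current root-to-node path
                # over the output vocabulary, padded to the longest path in the batch.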
                in_ = torch.zeros((batch_size, args.output_dim, max_w_len_path), dtype=torch.long)
                trg_list = [model.trg_pad_idx] * batch_size
                w_list_len = []

                for i in range(batch_size):
                    annotation = torch.zeros([graphs[i].number_of_nodes(), args.hid_dim - args.depth_dim], dtype=torch.long, device=device)
                    depth_annotation = torch.zeros([graphs[i].number_of_nodes(), args.depth_dim], dtype=torch.long, device=device)

                    if batch_w_list_trg[i] is not None and len(batch_w_list_trg[i]) > 0:
                        # One-hot encode node ids and depths, then flag the
                        # current node's slot in the last depth channel.
                        annotation.scatter_(1, graphs_data[i].view(-1, 1), value=1)
                        depth_annotation.scatter_(1, graphs_data_depth[i].view(-1, 1), value=1)
                        depth_annotation[cur_index_batch[i] - 1][-1] = 1
                        graphs[i].ndata['annotation'] = torch.cat([annotation, depth_annotation], dim=1).float()

                    w_list_trg = batch_w_list_trg[i]
                    t_path = batch_w_list[i]
                    if t_path is not None:
                        w_list_len.append(len(t_path)-1)
                    else:
                        w_list_len.append(0)
                    if w_list_trg is not None and len(w_list_trg) > 0:
                        trg_list[i] = w_list_trg[0]
                        for j in range(len(t_path)):
                            in_[i][t_path[j]][j] = 1
                        in_[i][-ic - 1][len(t_path) - 1] = 1
                        
                in_ = in_.float().permute(0, 2, 1).to(device)
                batch_graph = dgl.batch(graphs).to(device)
                trg_in = model.gnn(batch_graph) 
                if args.graph_aug:
                    batch_graph_tmp = preprocessing_batch_tmp(batch_graph_len_list, graphs, device).to(device)
                    assert batch_graph_tmp.num_nodes() == trg_in.view(-1, args.hid_dim).shape[0], 'not match ast graph'

                output = model.decoder(trg_in, in_, enc_src, src_mask, batch_graph=batch_graph_tmp)

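                # For each sample, take the decoder output at its final path
                # position; that vector scores the next-node prediction.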
                output_list = []
                for p in range(len(w_list_len)):
                    output_list.append(output[p][w_list_len[p]].view(1,-1))

                output = torch.cat(output_list, dim=0).view(batch_size, -1)
                trg_ = torch.tensor(trg_list, device=device)

                loss_itr, n_correct, n_word = cal_performance(
                    output, trg_, model.trg_pad_idx, smoothing=args.label_smoothing)
                loss += loss_itr
                n_word_total += n_word
                n_word_correct += n_correct
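                # flag == 1 means no sample had another cached step for this node:
                # advance every sample to its next node and reset the step counter.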
                cur_index = cur_index + flag
                cur_index_batch = [x + flag for x in cur_index_batch] 
                ic = 0 if flag == 1 else ic + 1

            if train_flag:
                # `optimizer` appears to be a scheduling wrapper around the raw
                # torch optimizer, hence the nested .optimizer here.
                optimizer.optimizer.zero_grad()
                loss.backward()
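                # With the module unwrapped above, gradients are averaged across
                # workers by hand rather than through DDP's automatic all-reduce.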
                if args.dist_gpu:
                    for param in model.parameters():
                        if param.requires_grad and param.grad is not None:
                            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
                            param.grad.data /= args.n_dist_gpu
                else:
                    args.summary.add_scalar(mode + '/loss', loss.item())
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
                optimizer.step()
            epoch_loss += loss.item()

    loss_per_word = epoch_loss/n_word_total
    accuracy = n_word_correct/n_word_total
    return loss_per_word, accuracy
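

A minimal sketch of how this function might be driven from an epoch loop. The
names train_iter, valid_iter, sched_optim, and args.n_epochs are assumptions
for illustration, not identifiers from this repository:

# Hypothetical driver loop; assumes `args`, `model`, `device`, `criterion`,
# and `dec_seq_length` are already constructed, and that `sched_optim` is the
# wrapper object whose inner optimizer the function resets via .optimizer.
for epoch in range(args.n_epochs):
    train_loss, train_acc = train_eval_tree(
        args, model, train_iter, sched_optim, device,
        criterion, dec_seq_length, train_flag=True)
    valid_loss, valid_acc = train_eval_tree(
        args, model, valid_iter, sched_optim, device,
        criterion, dec_seq_length, train_flag=False)
    print('epoch %d: train %.4f/%.2f%% | valid %.4f/%.2f%%' % (
        epoch, train_loss, 100 * train_acc, valid_loss, 100 * valid_acc))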