in baseline_model/data_utils/train_tree_encoder.py [0:0]
def train_eval_tree(args, model, iterator, optimizer, device,
                    criterion, dec_seq_length, train_flag=True):
    """Run one epoch of tree-decoder training (or validation) over ``iterator``.

    For every batch the assembly graphs are encoded once (``model.gnn_asm``,
    then ``model.encoder``).  The target tree is then decoded in lock-step
    across the batch.  At each step:

    * the partial AST graph cached in ``batch['dict_info']`` is embedded with
      ``model.gnn``;
    * the result is fed to ``model.decoder`` together with a one-hot path
      tensor;
    * the prediction at each sample's last path position is scored against
      the next target token with ``cal_performance``.

    Per-step losses are summed into one batch loss.  In training mode that
    loss is back-propagated and the optimizer is stepped once per batch.

    Args:
        args: Config namespace.  Reads ``sample_len``, ``dist_gpu``,
            ``n_dist_gpu``, ``graph_aug``, ``cache_path``, ``output_dim``,
            ``hid_dim``, ``depth_dim``, ``label_smoothing``, ``clip`` and
            ``summary`` (a scalar logger with ``add_scalar``).
        model: Model exposing ``gnn_asm``, ``make_src_mask``, ``encoder``,
            ``gnn``, ``decoder`` and ``trg_pad_idx``.  When ``args.dist_gpu``
            is set, ``model.module`` is used instead.
        iterator: Iterable of batch dicts with keys ``'dict_info'``,
            ``'trg'``, ``'id'``, ``'graphs_asm'`` and ``'src_len'``.
            ``dict_info[b]`` maps ``<cache_path>/<id>/<node>_<substep>`` to a
            dict with ``'batch_w_list'``, ``'batch_w_list_trg'``,
            ``'graphs'``, ``'graph_data'`` and ``'graph_depth'``.
        optimizer: Wrapper exposing ``.optimizer.zero_grad()`` and ``.step()``.
        device: Device that all tensors and graphs are moved to.
        criterion: Unused here; kept for interface compatibility.
        dec_seq_length: Unused here; kept for interface compatibility.
        train_flag: ``True`` puts the model in train mode and applies
            gradient updates.  ``False`` runs in eval mode with gradients
            disabled.

    Returns:
        ``(loss_per_word, accuracy)``.  The first is the summed epoch loss
        divided by the total word count reported by ``cal_performance``; the
        second is correct words over that same total.

    Raises:
        ZeroDivisionError: If no word was scored during the whole epoch
            (e.g. an empty ``iterator``).
    """
    if train_flag:
        mode = 'train'
        model.train()
    else:
        mode = 'valid'
        model.eval()

    n_word_total, n_word_correct = 0, 0
    epoch_loss = 0
    sample_len = args.sample_len
    batch_graph_tmp = None

    # Work on the wrapped module directly; gradients are synchronised by hand
    # further below.
    if args.dist_gpu:
        model = model.module

    with torch.set_grad_enabled(train_flag):
        for batch in iterator:
            dict_info = batch['dict_info']
            batch_size = len(batch['trg'])
            id_elem = batch['id']
            graphs_asm = batch['graphs_asm']
            src_len = batch['src_len']

            # Encode the source (assembly) graphs once per batch.
            batch_asm = dgl.batch(graphs_asm).to(device)
            enc_src = model.gnn_asm(batch_asm)
            src_mask = model.make_src_mask(enc_src.max(2)[0])
            if args.graph_aug:
                batch_graph_tmp = preprocessing_batch_tmp(src_len, graphs_asm, device).to(device)
            enc_src = model.encoder(enc_src, src_mask, batch_graph_tmp)

            # Count each sample's nodes, i.e. cache entries for sub-step 0
            # (basename "<node>_0").  Match the basename suffix rather than
            # any "_0" substring, so a cache dir or sample id containing
            # "_0" does not inflate the count.  For samples longer than
            # sample_len, pick a random start node for the decoding window.
            cur_index_batch = [1] * batch_size
            batch_nodes_num = [None] * batch_size
            for aa in range(batch_size):
                batch_nodes_num[aa] = sum(
                    1 for key in dict_info[aa]
                    if os.path.basename(key).endswith('_0'))
                if batch_nodes_num[aa] > sample_len:
                    rand_int = np.random.randint(-sample_len * 2 + 1, batch_nodes_num[aa])
                    # Clamp into [1, n_nodes - sample_len].  Every draw < 1
                    # maps to 1, so the window start is biased toward node 1.
                    cur_index_batch[aa] = min(max(rand_int, 1),
                                              batch_nodes_num[aa] - sample_len)

            # Number of node positions decoded for this batch.
            max_index = min(max(batch_nodes_num), sample_len)

            graphs = [None] * batch_size
            graphs_data = [None] * batch_size
            graphs_data_depth = [None] * batch_size

            cur_index = 1
            ic = 0          # sub-step within the current node
            loss = 0
            n_steps = 0     # decoding steps that contributed to `loss`
            while cur_index <= max_index:
                # flag == 1: every sample advances to its next node.
                # flag == 0: some sample still has a further sub-step, so
                # everyone stays on the current node and `ic` advances.
                flag = 1
                batch_w_list_trg = [None] * batch_size
                batch_w_list = [None] * batch_size
                batch_len_trg = [0] * batch_size
                batch_graph_len_list = [0] * batch_size

                for aa in range(batch_size):
                    path = os.path.join(args.cache_path, str(id_elem[aa]),
                                        str(cur_index_batch[aa]) + '_' + str(ic))
                    path_next = os.path.join(args.cache_path, str(id_elem[aa]),
                                             str(cur_index_batch[aa]) + '_' + str(ic + 1))
                    if path in dict_info[aa]:
                        entry = dict_info[aa][path]
                        batch_w_list[aa] = entry['batch_w_list']
                        batch_len_trg[aa] = len(entry['batch_w_list'])
                        batch_w_list_trg[aa] = entry['batch_w_list_trg']
                        graphs[aa] = entry['graphs'].to(device)
                        graphs_data[aa] = entry['graph_data'].to(device, non_blocking=True)
                        batch_graph_len_list[aa] = len(graphs_data[aa])
                        graphs_data_depth[aa] = entry['graph_depth'].to(device, non_blocking=True)
                    else:
                        # No cached step for this sample: use an empty graph
                        # so dgl.batch still sees one graph per sample.
                        graphs_data[aa] = None
                        graphs[aa] = dgl.DGLGraph().to(device)
                        batch_graph_len_list[aa] = 0
                    if path_next in dict_info[aa]:
                        flag = 0

                max_w_len_path = max(batch_len_trg)
                # Built on the host: many single-element writes follow, and the
                # tensor is moved to `device` once afterwards.
                in_ = torch.zeros((batch_size, args.output_dim, max_w_len_path), dtype=torch.long)
                trg_list = [model.trg_pad_idx] * batch_size
                w_list_len = []
                for b in range(batch_size):
                    n_nodes = graphs[b].number_of_nodes()
                    annotation = torch.zeros([n_nodes, args.hid_dim - args.depth_dim],
                                             dtype=torch.long, device=device)
                    depth_annotation = torch.zeros([n_nodes, args.depth_dim],
                                                   dtype=torch.long, device=device)
                    w_list_trg = batch_w_list_trg[b]
                    t_path = batch_w_list[b]
                    has_target = w_list_trg is not None and len(w_list_trg) > 0
                    if has_target:
                        # One-hot node features from the cached graph data.
                        annotation.scatter_(1, graphs_data[b].view(-1, 1), value=1)
                        depth_annotation.scatter_(1, graphs_data_depth[b].view(-1, 1), value=1)
                        # Set the last depth channel on row cur_index_batch[b]-1;
                        # presumably flags the node being expanded (confirm).
                        depth_annotation[cur_index_batch[b] - 1][-1] = 1
                        trg_list[b] = w_list_trg[0]
                        # One-hot encode the path: token t_path[j] at position j.
                        for j, tok in enumerate(t_path):
                            in_[b][tok][j] = 1
                        # Mark sub-step `ic` at the last path position, counting
                        # down from the last output_dim channel (presumably
                        # reserved step-indicator channels — confirm).
                        in_[b][-ic - 1][len(t_path) - 1] = 1
                    # Assigned even for empty graphs so every graph in the batch
                    # carries the same node feature.
                    graphs[b].ndata['annotation'] = torch.cat(
                        [annotation, depth_annotation], dim=1).float()
                    w_list_len.append(len(t_path) - 1 if t_path is not None else 0)

                in_ = in_.float().permute(0, 2, 1).to(device)
                batch_graph = dgl.batch(graphs).to(device)
                trg_in = model.gnn(batch_graph)
                if args.graph_aug:
                    batch_graph_tmp = preprocessing_batch_tmp(batch_graph_len_list, graphs, device).to(device)
                    assert batch_graph_tmp.num_nodes() == trg_in.view(-1, args.hid_dim).shape[0], 'not match ast graph'
                output = model.decoder(trg_in, in_, enc_src, src_mask, batch_graph=batch_graph_tmp)

                # Keep only each sample's prediction at its last path position.
                # Samples without a target are scored against trg_pad_idx.
                output = torch.cat(
                    [output[p][last].view(1, -1) for p, last in enumerate(w_list_len)],
                    dim=0).view(batch_size, -1)
                trg_ = torch.tensor(trg_list).to(device)
                loss_itr, n_correct, n_word = cal_performance(
                    output, trg_, model.trg_pad_idx, smoothing=args.label_smoothing)
                loss += loss_itr
                n_steps += 1
                n_word_total += n_word
                n_word_correct += n_correct

                cur_index += flag
                cur_index_batch = [x + flag for x in cur_index_batch]
                ic = 0 if flag == 1 else ic + 1

            if n_steps == 0:
                # No decoding step ran (no sample had a "<node>_0" entry).
                # `loss` is still the int 0, which has no .backward()/.item(),
                # so skip the batch.
                continue

            if train_flag:
                optimizer.optimizer.zero_grad()
                loss.backward()
                if args.dist_gpu:
                    # Manual gradient sync: sum across processes, then divide
                    # by args.n_dist_gpu.
                    for param in model.parameters():
                        if param.requires_grad and param.grad is not None:
                            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
                            param.grad.data /= args.n_dist_gpu
                else:
                    args.summary.add_scalar(mode + '/loss', loss.item())
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
                optimizer.step()
            epoch_loss += loss.item()

    loss_per_word = epoch_loss / n_word_total
    accuracy = n_word_correct / n_word_total
    return loss_per_word, accuracy