in baseline_model/data_utils/train_tree_encoder.py [0:0]
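
# Imports assumed by this excerpt (in the full module they sit at the top of
# the file); preprocessing_batch_tmp, get_novel_positional_encoding, and
# cal_performance are helpers defined elsewhere in the repository.
import torch
import dgl
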
def test_tree(args, model, iterator, trg_pad_idx, device, smoothing, criterion, clip):
    # criterion and clip are unused during evaluation; kept for interface parity.
    n_word_total, n_word_correct = 0, 0
    epoch_loss = 0
    batch_graph_tmp = None
    model.eval()
    # unwrap the DataParallel wrapper so submodules (gnn_asm, encoder, gnn,
    # decoder) can be called directly
    model = model.module
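    # Evaluation: encode each batch's assembly graphs, then decode the target
    # AST breadth-first, expanding one frontier of children per iteration.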
    with torch.set_grad_enabled(False):
        for batch in iterator:
            trg = batch['trg']
            graphs_asm = batch['graphs_asm']
            src_len = batch['src_len']
            # encode the batched assembly graphs with the source-side GNN
            batch_asm = dgl.batch(graphs_asm).to(device)
            enc_src = model.gnn_asm(batch_asm)
            src_mask = model.make_src_mask(enc_src.max(2)[0])
            if args.graph_aug:
                batch_graph_tmp = preprocessing_batch_tmp(src_len, graphs_asm, device).to(device)
            enc_src = model.encoder(enc_src, src_mask, batch_graph_tmp)
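            # Per-sample decoding state: queue_tree[i] is a BFS queue of nodes still
            # to expand; graphs[i] accumulates the partial AST fed to the target-side
            # GNN; graphs_data / graphs_data_depth record each node's predicted label
            # and depth for the one-hot node annotations built below.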
            batch_size = len(trg)
            queue_tree = {}
            graphs = []
            graphs_data = []
            graphs_data_depth = []
            total_tree_num = [1 for _ in range(batch_size)]
            for i in range(1, batch_size + 1):
                # seed each queue with the root of the target tree
                queue_tree[i] = [{
                    "tree": trg[i-1], "parent": 0, "child_index": 1,
                    "tree_path": [], "depth": 0,
                    "child_num": len(trg[i-1].children),
                    "encoding": [], "predict": trg[i-1].value,
                }]
                total_tree_num[i-1] += len(trg[i-1].children)
                trg[i-1].predict = trg[i-1].value
                graphs.append(dgl.DGLGraph())
                graphs_data.append([])
                graphs_data_depth.append([])
            cur_index, max_index = 1, 1
            loss, ic = 0, 0
            last_append = [None] * batch_size
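            # cur_index walks positions in the per-sample queues; ic indexes the
            # current node's children. flag stays 1 only when every sample has
            # exhausted its current node's children, at which point decoding
            # advances to the next queued node.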
            while cur_index <= max_index:
                max_w_len = -1
                max_w_len_path = -1
                batch_w_list_trg = []
                batch_w_list = []
                flag = 1
                t = [None] * batch_size
                batch_graph_len_list = [0] * batch_size
                graphs_tmp = [dgl.DGLGraph() for _ in range(batch_size)]  # one graph per slot, no shared refs
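                # First pass over the batch: look up each sample's current node,
                # grow its partial AST with the child generated on the previous
                # step, and queue the next child to expand.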
                for i in range(1, batch_size + 1):
                    w_list_trg = []
                    t_path = []  # stays empty for samples whose queue is already exhausted
                    if cur_index <= len(queue_tree[i]):
                        t_node = queue_tree[i][cur_index - 1]
                        t_encode = t_node["encoding"]
                        t_depth = t_node["depth"]
                        t[i-1] = t_node["tree"]
                        if ic == 0:
                            queue_tree[i][cur_index - 1]["tree_path"].append(t[i-1].predict)
                        t_path = queue_tree[i][cur_index - 1]["tree_path"].copy()
                        if ic == 0 and cur_index == 1:
                            # root node: add it to the AST graph with its label and depth
                            graphs[i-1].add_nodes(1)
                            graphs[i-1].add_edges(t_node["parent"], cur_index - 1)
                            graphs_data[i-1].append(t[i-1].predict)
                            graphs_data_depth[i-1].append(t_depth)
                        elif ic <= t_node['child_num'] - 1:
                            # attach the child generated on the previous step
                            t_node_child = last_append[i-1]
                            graphs[i-1].add_nodes(1)
                            graphs[i-1].add_edges(t_node_child["parent"], len(queue_tree[i]) - 1)
                            graphs_data[i-1].append(t_node_child["tree"].predict)
                            graphs_data_depth[i-1].append(t_node_child["depth"])
                        # if not all children have been expanded yet, queue the next one
                        if ic <= t_node['child_num'] - 1:
                            w_list_trg.append(t[i-1].children[ic].value)
                            encoding = get_novel_positional_encoding(t[i-1].children[ic], ic, t_node)
                            if t[i-1].children[ic].value != 0:
                                last_append[i-1] = {"tree": t[i-1].children[ic], "parent": cur_index - 1,
                                                    "child_index": ic, "tree_path": t_path,
                                                    "depth": t_depth + 1,
                                                    "child_num": len(t[i-1].children[ic].children),
                                                    "encoding": encoding}
                                if len(t[i-1].children[ic].children) > 0:
                                    queue_tree[i].append({"tree": t[i-1].children[ic], "parent": cur_index - 1,
                                                          "child_index": ic, "tree_path": t_path,
                                                          "depth": t_depth + 1,
                                                          "child_num": len(t[i-1].children[ic].children),
                                                          "encoding": encoding})
                        batch_graph_len_list[i-1] = len(graphs_data[i-1])
                        graphs_tmp[i-1] = graphs[i-1]
                        # keep expanding this node's children before moving on
                        if ic + 1 < t_node['child_num']:
                            flag = 0
                    else:
                        batch_graph_len_list[i-1] = 0
                        graphs_tmp[i-1] = dgl.DGLGraph()
                    if len(queue_tree[i]) > max_index:
                        max_index = len(queue_tree[i])
                    if len(t_path) > max_w_len_path:
                        max_w_len_path = len(t_path)
                    if len(graphs_data[i-1]) > max_w_len:
                        max_w_len = len(graphs_data[i-1])
                    batch_w_list_trg.append(w_list_trg)
                    batch_w_list.append(t_path)
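                # Second pass: build decoder inputs. Each root-to-node path is
                # one-hot encoded into in_, and each partial AST gets node
                # annotations of [label one-hot | depth one-hot].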
                trg_l = [trg_pad_idx for _ in range(batch_size)]
                w_list_len = []
                in_ = torch.zeros((batch_size, args.output_dim, max_w_len_path), dtype=torch.long)
                for i in range(batch_size):
                    annotation = torch.zeros([graphs[i].number_of_nodes(), args.hid_dim - args.depth_dim], dtype=torch.long)
                    depth_annotation = torch.zeros([graphs[i].number_of_nodes(), args.depth_dim], dtype=torch.long)
                    for idx in range(len(graphs_data[i])):
                        annotation[idx][graphs_data[i][idx]] = 1
                        depth_annotation[idx][graphs_data_depth[i][idx]] = 1
                    if len(batch_w_list_trg[i]) > 0:
                        # flag the node currently being expanded
                        depth_annotation[cur_index-1][-1] = 1
                    graphs[i].ndata['annotation'] = torch.cat([annotation, depth_annotation], dim=1)
                    w_list_trg = batch_w_list_trg[i]
                    t_path = batch_w_list[i]
                    w_list_len.append(len(t_path) - 1)
                    if len(w_list_trg) > 0:
                        trg_l[i] = w_list_trg[0]
                        for j in range(len(t_path)):
                            in_[i][t_path[j]][j] = 1
                        # mark the child index being predicted at the end of the path
                        in_[i][-ic-1][len(t_path)-1] = 1
                in_ = in_.float().permute(0, 2, 1).to(device)
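                # Run the target-side GNN over the batched partial ASTs, then decode
                # one next-node prediction per sample, attending to the encoded assembly.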
                if args.graph_aug:
                    batch_graph_tmp = preprocessing_batch_tmp(batch_graph_len_list, graphs_tmp, device).to(device)
                batch_graph = dgl.batch(graphs_tmp).to(device)
                trg_in = model.gnn(batch_graph)
                if batch_graph_tmp is not None:
                    assert batch_graph_tmp.num_nodes() == trg_in.view(-1, args.hid_dim).shape[0], \
                        'AST graph does not match GNN output'
                output_l = model.decoder(trg_in, in_, enc_src, src_mask, batch_graph=batch_graph_tmp)
                # keep, for each sample, only the logits at the last position of its path
                output_l_list = []
                for p in range(len(w_list_len)):
                    output_l_list.append(output_l[p][w_list_len[p]].view(1, -1))
                output = torch.cat(output_l_list, dim=0).view(batch_size, -1)
                trg_ = torch.tensor(trg_l).to(device)
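                # Write predictions back into the target tree so later steps condition
                # on them: a child's predicted label feeds the AST annotations and tree
                # paths of subsequent iterations.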
                output_predict_list = output.argmax(1).tolist()
                for p, elem in enumerate(output_predict_list):
                    if t[p] is not None and len(t[p].children) > ic:
                        if cur_index < 2:
                            # the first expanded node is the root: feed the ground-truth
                            # label (a plain int, consistent with the branch below)
                            t[p].children[ic].predict = trg_l[p]
                        else:
                            t[p].children[ic].predict = elem
                loss_itr, n_correct, n_word = cal_performance(
                    output, trg_, trg_pad_idx, smoothing=smoothing)
                loss += loss_itr
                n_word_total += n_word
                n_word_correct += n_correct
                # advance to the next queued node only when every sample has finished
                # the children of its current node; otherwise move to the next child
                cur_index = cur_index + flag
                ic = 0 if flag == 1 else ic + 1
            epoch_loss += loss.item()
    loss_per_word = epoch_loss / n_word_total
    accuracy = n_word_correct / n_word_total
    return loss_per_word, accuracy