in baseline_model/data_utils/train_tree_encoder_v2.py [0:0]
def eval_training(opt, iterator, encoder, decoder_l, decoder_r, attention_decoder,
                  encoder_optimizer, decoder_optimizer_l, decoder_optimizer_r,
                  attention_decoder_optimizer, criterion, using_gpu):
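    """Run one training epoch (despite the name, this trains: every module is
    put in train mode and optimizer steps are taken). Each source batch is
    encoded one time step at a time, then the target trees are decoded
    breadth-first with a left-child decoder, a right-child decoder, and an
    attention layer over the encoder outputs. Returns the mean per-batch loss.
    Assumes `torch` is imported at module level."""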
    epoch_loss = 0
    encoder.train()
    decoder_r.train()
    decoder_l.train()
    attention_decoder.train()
    for batch in iterator:
        encoder_optimizer.zero_grad()
        decoder_optimizer_l.zero_grad()
        decoder_optimizer_r.zero_grad()
        attention_decoder_optimizer.zero_grad()
        enc_batch = batch.src
        dec_tree_batch = batch.trg
        enclen = batch.enclen
        # use the fixed padded length from opt so enc_outputs has a static shape
        enc_max_len = opt.enc_seq_length
        # plain zeros buffer: encoder states are written into it in-place below,
        # which pulls it into the autograd graph; creating it as a leaf with
        # requires_grad=True would raise an error on CPU
        enc_outputs = torch.zeros((len(enc_batch), enc_max_len, encoder.hidden_size))
        if using_gpu:
            enc_outputs = enc_outputs.cuda()
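        # State tables. enc_s[t][1] and enc_s[t][2] hold the two encoder LSTM
        # states after step t (index 2 is the output fed to attention).
        # dec_s[n][slot][state] is indexed per tree node n: slot 0 is the
        # node's input state, slots 1 and 2 are the states produced by the
        # left- and right-child decoders.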
        enc_s = {}
        for j in range(opt.enc_seq_length + 1):
            enc_s[j] = {}
        dec_s = {}
        for i in range(opt.dec_seq_length + 1):
            dec_s[i] = {}
            for j in range(3):
                dec_s[i][j] = {}
        for i in range(1, 3):
            enc_s[0][i] = torch.zeros((opt.batch_size, opt.rnn_size), dtype=torch.float, requires_grad=True)
            if using_gpu:
                enc_s[0][i] = enc_s[0][i].cuda()
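        # unroll the encoder one time step at a time and keep every
        # intermediate state: enc_s[t] is read back later to recover each
        # example's state at its own true source length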
        for i in range(enc_max_len):
            enc_s[i + 1][1], enc_s[i + 1][2] = encoder(enc_batch, i, enc_s[i][1], enc_s[i][2])
            enc_outputs[:, i, :] = enc_s[i + 1][2]
        # tree decode
        queue_tree = {}
        for i in range(1, opt.batch_size + 1):
            queue_tree[i] = []
            queue_tree[i].append({"tree": dec_tree_batch[i - 1], "parent": 0, "child_index": 1})
        loss = 0
        cur_index, max_index = 1, 1
        dec_batch = {}
        dec_batch_trg = {}
        while cur_index <= max_index:
            # build dec_batch for cur_index: for each example, gather the
            # current node's value (the decoder input, repeated once per
            # child) and the children's values (the prediction targets)
            batch_w_list = []
            batch_w_list_trg = []
            for i in range(1, opt.batch_size + 1):
                w_list = []
                w_list_trg = []
                if cur_index <= len(queue_tree[i]):
                    t = queue_tree[i][cur_index - 1]["tree"]
                    for ic in range(len(t.children)):
                        w_list.append(t.value)
                        w_list_trg.append(t.children[ic].value)
                        if t.children[ic].value != 0:
                            # child_index is 1-based to match the dec_s layout
                            # read back below: slot 1 is the left-decoder
                            # state, slot 2 the right-decoder state
                            queue_tree[i].append({"tree": t.children[ic], "parent": cur_index, "child_index": ic + 1})
                    if len(queue_tree[i]) > max_index:
                        max_index = len(queue_tree[i])
                batch_w_list.append(w_list)
                batch_w_list_trg.append(w_list_trg)
            # fixed width 2: the trees are binary (one left and one right child)
            dec_batch[cur_index] = torch.zeros((opt.batch_size, 2), dtype=torch.long)
            dec_batch_trg[cur_index] = torch.zeros((opt.batch_size, 2), dtype=torch.long)
            for i in range(opt.batch_size):
                w_list = batch_w_list[i]
                w_list_trg = batch_w_list_trg[i]
                if len(w_list) > 0:
                    for j in range(len(w_list)):
                        dec_batch[cur_index][i][j] = w_list[j]
                        dec_batch_trg[cur_index][i][j] = w_list_trg[j]
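            # slots with no child keep value 0; 0-valued children are never
            # queued for expansion (see the check above), so 0 evidently acts
            # as the empty-child / stop symbol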
            if using_gpu:
                dec_batch[cur_index] = dec_batch[cur_index].cuda()
                dec_batch_trg[cur_index] = dec_batch_trg[cur_index].cuda()
            # initialize the current node's input state (slot 0) with zeros;
            # it is filled in-place below, so it must not be a leaf that
            # requires grad (that would raise an error on CPU)
            for j in range(1, 3):
                dec_s[cur_index][0][j] = torch.zeros((opt.batch_size, opt.rnn_size), dtype=torch.float)
                if using_gpu:
                    dec_s[cur_index][0][j] = dec_s[cur_index][0][j].cuda()
            # dec_s[cur_index][slot][state]: slot 0 = node input, slot 1 =
            # after decoder_l, slot 2 = after decoder_r; states 1 and 2 are
            # the two LSTM states (2 is the hidden output used by attention)
            if cur_index == 1:
                # root nodes: start from the encoder state at each example's
                # true source length rather than the padded final step
                for i in range(opt.batch_size):
                    dec_s[1][0][1][i, :] = enc_s[int(enclen[i])][1][i, :]
                    dec_s[1][0][2][i, :] = enc_s[int(enclen[i])][2][i, :]
            else:
                # inner nodes: resume from the parent's child-decoder state
                # for this node's slot (1 = left child, 2 = right child)
                for i in range(1, opt.batch_size + 1):
                    if cur_index <= len(queue_tree[i]):
                        par_index = queue_tree[i][cur_index - 1]["parent"]
                        child_index = queue_tree[i][cur_index - 1]["child_index"]
                        dec_s[cur_index][0][1][i - 1, :] = dec_s[par_index][child_index][1][i - 1, :]
                        dec_s[cur_index][0][2][i - 1, :] = dec_s[par_index][child_index][2][i - 1, :]
            parent_h = dec_s[cur_index][0][2]  # node hidden state, fed to both child decoders
            # left child: one decoder step from the node's input state, then
            # attention over the encoder outputs to score the vocabulary
            dec_s[cur_index][1][1], dec_s[cur_index][1][2] = decoder_l(
                dec_batch[cur_index][:, 0], dec_s[cur_index][0][1], dec_s[cur_index][0][2], parent_h)
            pred_l = attention_decoder(enc_outputs, dec_s[cur_index][1][2])
            loss += criterion(pred_l, dec_batch_trg[cur_index][:, 0])
            # right child: same step with the right-child decoder
            dec_s[cur_index][2][1], dec_s[cur_index][2][2] = decoder_r(
                dec_batch[cur_index][:, 1], dec_s[cur_index][0][1], dec_s[cur_index][0][2], parent_h)
            pred_r = attention_decoder(enc_outputs, dec_s[cur_index][2][2])
            loss += criterion(pred_r, dec_batch_trg[cur_index][:, 1])
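            # note: decoder inputs are gold node values (t.value), i.e. teacher
            # forcing; model predictions never feed back in during training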
            cur_index += 1
        loss = loss / opt.batch_size
        loss.backward()
        torch.nn.utils.clip_grad_value_(encoder.parameters(), opt.grad_clip)
        torch.nn.utils.clip_grad_value_(decoder_l.parameters(), opt.grad_clip)
        torch.nn.utils.clip_grad_value_(decoder_r.parameters(), opt.grad_clip)
        torch.nn.utils.clip_grad_value_(attention_decoder.parameters(), opt.grad_clip)
        encoder_optimizer.step()
        decoder_optimizer_l.step()
        decoder_optimizer_r.step()
        attention_decoder_optimizer.step()
#print("end eval training \n ")
#print("=====================\n")
        epoch_loss += loss.item()
    return epoch_loss / len(iterator)
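
# A minimal sketch of how this function might be driven (hypothetical: the
# iterator, modules, optimizers, and the NLLLoss criterion below are
# assumptions, not defined in this file):
#
#     for epoch in range(opt.max_epochs):
#         train_loss = eval_training(
#             opt, train_iterator, encoder, decoder_l, decoder_r,
#             attention_decoder, encoder_optimizer, decoder_optimizer_l,
#             decoder_optimizer_r, attention_decoder_optimizer,
#             torch.nn.NLLLoss(), using_gpu=torch.cuda.is_available())
#         print("epoch %d: mean batch loss %.4f" % (epoch, train_loss))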