baseline_model/data_utils/train_tree_encoder_v2.py [42:77]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        # --- Encoder setup + step-by-step unroll (fragment; enclosing def is above this view) ---
        # Unpack the batch: encoder token ids, target tree, and per-example source lengths.
        enc_batch = batch.src #.transpose(1,0)
        dec_tree_batch  = batch.trg
        enclen = batch.enclen
        # enc_max_len  = enc_batch.size(1) #batch.src.shape[1]
        # Unroll for a fixed opt.enc_seq_length steps, not the actual batch width
        # (the per-batch variant is kept commented above for reference).
        enc_max_len  = opt.enc_seq_length #enc_batch.size(1)
        # Per-timestep encoder hidden states, filled in-place in the loop below.
        # NOTE(review): creating a *leaf* tensor with requires_grad=True and then
        # writing it in-place (enc_outputs[:, i, :] = ...) raises a RuntimeError on
        # modern PyTorch; the .cuda() call below returns a non-leaf, so only the GPU
        # path sidesteps this. requires_grad=False would let gradients flow through
        # enc_s instead -- confirm against the targeted PyTorch version.
        enc_outputs = torch.zeros((len(enc_batch), enc_max_len, encoder.hidden_size), requires_grad=True)

        if using_gpu:
            enc_outputs = enc_outputs.cuda()
        # enc_s[t][1] / enc_s[t][2]: the two recurrent state tensors at time t
        # (presumably cell and hidden state, per the encoder(...) return order
        # below -- TODO confirm against the encoder implementation).
        enc_s = {}
        for j in range(opt.enc_seq_length + 1):
            enc_s[j] = {}

        # dec_s[i][j][k]: decoder state table indexed by decode step i and slot j;
        # populated by the tree-decoding code that follows this fragment.
        dec_s = {}
        for i in range(opt.dec_seq_length + 1):
            dec_s[i] = {}
            for j in range(3):
                dec_s[i][j] = {}

        # Zero initial encoder state (slots 1 and 2) for the whole batch.
        for i in range(1, 3):
            enc_s[0][i] = torch.zeros((opt.batch_size, opt.rnn_size), dtype=torch.float, requires_grad=True)
            if using_gpu:
                enc_s[0][i] = enc_s[0][i].cuda()

        # pdb.set_trace()
        # TODO:change this part
        # import time
        # start = time.time()
        # Step the encoder one token at a time, threading the state forward and
        # recording each step's second state tensor as the step's output
        # (assumes it is shaped (batch, hidden) -- TODO confirm).
        for i in range(enc_max_len):
            enc_s[i+1][1], enc_s[i+1][2] = encoder(enc_batch, i, enc_s[i][1], enc_s[i][2])
            enc_outputs[:, i, :] = enc_s[i+1][2]

        # end = time.time()
        # print("time encoding: " + str(end-start) )
        # tree decode
        queue_tree = {}
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



baseline_model/data_utils/train_tree_encoder_v2.py [231:259]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        # --- Encoder setup + unroll (fragment; enclosing def is above this view) ---
        # NOTE(review): near-duplicate of the same setup in the other function of
        # this file (lines 42:77) -- a shared helper would avoid the drift risk.
        # Unpack the batch: encoder token ids, target tree, and per-example source lengths.
        enc_batch = batch.src # .transpose(1,0)
        dec_tree_batch  = batch.trg
        enclen = batch.enclen
        # Fixed unroll length from options rather than the actual batch width.
        enc_max_len  = opt.enc_seq_length #enc_batch.size(1)
        # enc_outputs = torch.zeros((enc_batch.size(0), enc_max_len, encoder.hidden_size), requires_grad=True)
        # Per-timestep encoder hidden states, filled in-place in the loop below.
        # NOTE(review): in-place writes into a leaf tensor with requires_grad=True
        # raise on modern PyTorch (the .cuda() result below is a non-leaf, so only
        # the GPU path sidesteps this) -- confirm the targeted PyTorch version.
        enc_outputs = torch.zeros((len(enc_batch), enc_max_len, encoder.hidden_size), requires_grad=True)
        if using_gpu:
            enc_outputs = enc_outputs.cuda()
        # enc_s[t][1] / enc_s[t][2]: the two recurrent state tensors at time t
        # (presumably cell and hidden state -- TODO confirm).
        enc_s = {}
        for j in range(opt.enc_seq_length + 1):
            enc_s[j] = {}

        # dec_s[i][j][k]: decoder state table, populated by the tree-decoding
        # code that follows this fragment.
        dec_s = {}
        for i in range(opt.dec_seq_length + 1):
            dec_s[i] = {}
            for j in range(3):
                dec_s[i][j] = {}

        # Zero initial encoder state (slots 1 and 2) for the whole batch.
        for i in range(1, 3):
            enc_s[0][i] = torch.zeros((opt.batch_size, opt.rnn_size), dtype=torch.float, requires_grad=True)
            if using_gpu:
                enc_s[0][i] = enc_s[0][i].cuda()

        # TODO:change this part
        # Step the encoder one token at a time, threading state forward and
        # recording each step's second state tensor as the step's output.
        for i in range(enc_max_len):
            enc_s[i+1][1], enc_s[i+1][2] = encoder(enc_batch, i, enc_s[i][1], enc_s[i][2])
            enc_outputs[:, i, :] = enc_s[i+1][2]
        # tree decode
        queue_tree = {}
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



