baseline_model/data_utils/train.py [167:288]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            trg_seq, gold = map(lambda x: x.to(device), patch_trg(batch.trg, trg_pad_idx))
            output_dim = output.shape[-1]

            output = output.contiguous().view(-1, output_dim)
            trg = trg[:,1:].contiguous().view(-1)

            loss, n_correct, n_word = cal_performance(
                output, gold, trg_pad_idx, smoothing=False)
            # loss = criterion(output, trg)

            n_word_total += n_word
            n_word_correct += n_correct
            epoch_loss += loss.item()

    loss_per_word = epoch_loss/n_word_total
    accuracy = n_word_correct/n_word_total

    return loss_per_word, accuracy
    # return epoch_loss / len(iterator)

# def evaluate(model, iterator, criterion,dump_file=False):
#
#     model.eval()
#     epoch_loss = 0
#     trg_all = []
#     output_all = []
#     collect_hidden_all = []
#     with torch.no_grad():
#         for i, batch in enumerate(iterator):
#
#             src = batch.src
#             trg_raw = batch.trg
#             output_raw, collect_hidden,_ = model(src, trg_raw, 0) #turn off teacher forcing
#             # output_raw = model(src, trg_raw, 0) #turn off teacher forcing
#
#             #trg = [trg sent len, batch size]
#             #output = [trg sent len, batch size, output dim]
#
#             output = output_raw[1:].view(-1, output_raw.shape[-1])
#             trg = trg_raw[1:].view(-1)
#
#             #trg = [(trg sent len - 1) * batch size]
#             #output = [(trg sent len - 1) * batch size, output dim]
#
#             loss = criterion(output, trg)
#
#             epoch_loss += loss.item()
#             if(dump_file is True):
#                 collect_hidden_all.append(collect_hidden)
#                 trg_all.append(trg_raw)
#                 output_all.append(torch.max(output_raw,2)[1])
#     return trg_all, output_all, collect_hidden_all, epoch_loss / len(iterator)

def evaluate_att(model, iterator, criterion, dump_file=False):
    """Return the mean per-batch loss of ``model`` over ``iterator`` (no grads).

    Every batch's ``src`` and ``trg`` are transposed with ``transpose(1, 0)``
    before use.  The model receives the target minus its last token and is
    scored against the target minus its first token.

    Args:
        model: module called as ``model(src, trg_input)``; switched to eval mode.
        iterator: iterable of batches exposing ``.src`` / ``.trg``; must also
            support ``len()`` (used as the loss denominator).
        criterion: loss called as ``criterion(logits_2d, gold_1d)``.
        dump_file: only when it is exactly ``True`` (identity test) are the
            targets and greedy predictions collected.

    Returns:
        ``(trg_all, output_all, mean_loss)``. ``trg_all`` holds each batch's
        original, untransposed ``batch.trg``. ``output_all`` holds the argmax
        over dim 2 of the model output. ``mean_loss`` is the summed batch loss
        divided by ``len(iterator)``; an empty iterator therefore raises
        ``ZeroDivisionError``.

    NOTE(review): ``trg_all`` entries are in the pre-transpose layout, while
    ``output_all`` entries follow the model-output layout
    ``[batch size, trg sent len - 1]``.  Consumers must reconcile the two.
    """
    model.eval()

    dumped_targets = []
    dumped_predictions = []
    total_loss = 0

    with torch.no_grad():
        for batch in iterator:
            source = batch.src.transpose(1, 0)
            target = batch.trg.transpose(1, 0)

            # logits = [batch size, trg sent len - 1, output dim]
            logits = model(source, target[:, :-1])
            vocab_size = logits.shape[-1]

            # Flatten to [batch size * (trg sent len - 1), output dim] and
            # [batch size * (trg sent len - 1)] for the criterion.
            flat_logits = logits.contiguous().view(-1, vocab_size)
            flat_gold = target[:, 1:].contiguous().view(-1)

            batch_loss = criterion(flat_logits, flat_gold)
            total_loss += batch_loss.item()

            if dump_file is True:
                dumped_targets.append(batch.trg)
                dumped_predictions.append(torch.max(logits, dim=2).indices)

    return dumped_targets, dumped_predictions, total_loss / len(iterator)

def train_att(model, iterator, optimizer, criterion, clip):
    """Run one teacher-forced training pass and return the mean batch loss.

    Each batch's ``src`` / ``trg`` are transposed with ``transpose(1, 0)``.
    The model is fed the target minus its last token and trained to predict
    the target minus its first token.

    Args:
        model: module called as ``model(src, trg_input)``; switched to train mode.
        iterator: iterable of batches exposing ``.src`` / ``.trg``; must also
            support ``len()`` (used as the loss denominator).
        optimizer: wrapper object.  Gradients are zeroed on its ``.optimizer``
            attribute, and the wrapper's own ``step()`` is called after clipping.
            This is presumably a learning-rate-scheduling wrapper; confirm
            against the caller.
        criterion: loss called as ``criterion(logits_2d, gold_1d)``.
        clip: max gradient norm passed to ``torch.nn.utils.clip_grad_norm_``.

    Returns:
        Sum of ``loss.item()`` over batches divided by ``len(iterator)``.
        An empty iterator raises ``ZeroDivisionError``.
    """
    model.train()
    running_loss = 0

    for batch in iterator:
        source = batch.src.transpose(1, 0)
        target = batch.trg.transpose(1, 0)  # [batch size, trg sent len]

        optimizer.optimizer.zero_grad()

        # Shift by one: input drops the last token, gold drops the first.
        decoder_input = target[:, :-1]
        gold = target[:, 1:].contiguous().view(-1)

        # logits = [batch size, trg sent len - 1, output dim]
        logits = model(source, decoder_input)
        flat_logits = logits.contiguous().view(-1, logits.shape[-1])

        batch_loss = criterion(flat_logits, gold)
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        running_loss += batch_loss.item()

    return running_loss / len(iterator)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



baseline_model/data_utils/train_gnn.py [281:403]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            trg_seq, gold = map(lambda x: x.to(device), patch_trg(batch.trg, trg_pad_idx))
            output_dim = output.shape[-1]

            output = output.contiguous().view(-1, output_dim)
            trg = trg[:,1:].contiguous().view(-1)

            loss, n_correct, n_word = cal_performance(
                output, gold, trg_pad_idx, smoothing=False)
            # loss = criterion(output, trg)

            n_word_total += n_word
            n_word_correct += n_correct
            epoch_loss += loss.item()

    loss_per_word = epoch_loss/n_word_total
    accuracy = n_word_correct/n_word_total

    return loss_per_word, accuracy
    # return epoch_loss / len(iterator)

# def evaluate(model, iterator, criterion,dump_file=False):
#
#     model.eval()
#     epoch_loss = 0
#     trg_all = []
#     output_all = []
#     collect_hidden_all = []
#     with torch.no_grad():
#         for i, batch in enumerate(iterator):
#
#             src = batch.src
#             trg_raw = batch.trg
#             output_raw, collect_hidden,_ = model(src, trg_raw, 0) #turn off teacher forcing
#             # output_raw = model(src, trg_raw, 0) #turn off teacher forcing
#
#             #trg = [trg sent len, batch size]
#             #output = [trg sent len, batch size, output dim]
#
#             output = output_raw[1:].view(-1, output_raw.shape[-1])
#             trg = trg_raw[1:].view(-1)
#
#             #trg = [(trg sent len - 1) * batch size]
#             #output = [(trg sent len - 1) * batch size, output dim]
#
#             loss = criterion(output, trg)
#
#             epoch_loss += loss.item()
#             if(dump_file is True):
#                 collect_hidden_all.append(collect_hidden)
#                 trg_all.append(trg_raw)
#                 output_all.append(torch.max(output_raw,2)[1])
#     return trg_all, output_all, collect_hidden_all, epoch_loss / len(iterator)

def evaluate_att(model, iterator, criterion, dump_file=False):
    """Return the mean per-batch loss of ``model`` over ``iterator`` (no grads).

    Every batch's ``src`` and ``trg`` are transposed with ``transpose(1, 0)``
    before use.  The model receives the target minus its last token and is
    scored against the target minus its first token.

    Args:
        model: module called as ``model(src, trg_input)``; switched to eval mode.
        iterator: iterable of batches exposing ``.src`` / ``.trg``; must also
            support ``len()`` (used as the loss denominator).
        criterion: loss called as ``criterion(logits_2d, gold_1d)``.
        dump_file: only when it is exactly ``True`` (identity test) are the
            targets and greedy predictions collected.

    Returns:
        ``(trg_all, output_all, mean_loss)``. ``trg_all`` holds each batch's
        original, untransposed ``batch.trg``. ``output_all`` holds the argmax
        over dim 2 of the model output. ``mean_loss`` is the summed batch loss
        divided by ``len(iterator)``; an empty iterator therefore raises
        ``ZeroDivisionError``.

    NOTE(review): ``trg_all`` entries are in the pre-transpose layout, while
    ``output_all`` entries follow the model-output layout
    ``[batch size, trg sent len - 1]``.  Consumers must reconcile the two.
    """
    model.eval()

    dumped_targets = []
    dumped_predictions = []
    total_loss = 0

    with torch.no_grad():
        for batch in iterator:
            source = batch.src.transpose(1, 0)
            target = batch.trg.transpose(1, 0)

            # logits = [batch size, trg sent len - 1, output dim]
            logits = model(source, target[:, :-1])
            vocab_size = logits.shape[-1]

            # Flatten to [batch size * (trg sent len - 1), output dim] and
            # [batch size * (trg sent len - 1)] for the criterion.
            flat_logits = logits.contiguous().view(-1, vocab_size)
            flat_gold = target[:, 1:].contiguous().view(-1)

            batch_loss = criterion(flat_logits, flat_gold)
            total_loss += batch_loss.item()

            if dump_file is True:
                dumped_targets.append(batch.trg)
                dumped_predictions.append(torch.max(logits, dim=2).indices)

    return dumped_targets, dumped_predictions, total_loss / len(iterator)

def train_att(model, iterator, optimizer, criterion, clip):
    """Run one teacher-forced training pass and return the mean batch loss.

    Each batch's ``src`` / ``trg`` are transposed with ``transpose(1, 0)``.
    The model is fed the target minus its last token and trained to predict
    the target minus its first token.

    Args:
        model: module called as ``model(src, trg_input)``; switched to train mode.
        iterator: iterable of batches exposing ``.src`` / ``.trg``; must also
            support ``len()`` (used as the loss denominator).
        optimizer: wrapper object.  Gradients are zeroed on its ``.optimizer``
            attribute, and the wrapper's own ``step()`` is called after clipping.
            This is presumably a learning-rate-scheduling wrapper; confirm
            against the caller.
        criterion: loss called as ``criterion(logits_2d, gold_1d)``.
        clip: max gradient norm passed to ``torch.nn.utils.clip_grad_norm_``.

    Returns:
        Sum of ``loss.item()`` over batches divided by ``len(iterator)``.
        An empty iterator raises ``ZeroDivisionError``.
    """
    model.train()
    running_loss = 0

    for batch in iterator:
        source = batch.src.transpose(1, 0)
        target = batch.trg.transpose(1, 0)  # [batch size, trg sent len]

        optimizer.optimizer.zero_grad()

        # Shift by one: input drops the last token, gold drops the first.
        decoder_input = target[:, :-1]
        gold = target[:, 1:].contiguous().view(-1)

        # logits = [batch size, trg sent len - 1, output dim]
        logits = model(source, decoder_input)
        flat_logits = logits.contiguous().view(-1, logits.shape[-1])

        batch_loss = criterion(flat_logits, gold)
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        running_loss += batch_loss.item()

    return running_loss / len(iterator)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



