def compute_total_loss()

in trainer.py [0:0]


def compute_total_loss(args, out, Y, corpus, aux_loss):
    """Return the ``(loss, aux_loss, err)`` triple for one batch.

    When ``args.data_omit_label_idx`` is set, the whole computation is
    delegated to ``compute_masked_loss`` and its result is returned as is.

    Otherwise:

    * ``out`` is viewed as ``(-1, out.size(-1))`` and ``Y`` as a flat
      vector, then both are passed to ``F.nll_loss`` (so ``out`` is
      presumably log-probabilities over classes -- confirm with the model).
    * ``aux_loss`` is reduced with ``.mean()`` if it is a tensor and is
      passed through unchanged otherwise.
    * ``err`` is the mean fraction of positions where the highest-scoring
      class differs from ``Y``, computed only when ``corpus`` has a
      ``train_labels`` attribute; otherwise it is the plain integer ``-1``.
    """
    # Label masking is configured: the dedicated routine handles everything.
    if args.data_omit_label_idx is not None:
        return compute_masked_loss(args, out, Y, corpus, aux_loss)

    # Fold the batch and temporal dims together: one row per prediction.
    num_classes = out.size(-1)
    flat_out = out.view(-1, num_classes)
    flat_targets = Y.view(-1)

    loss = F.nll_loss(flat_out, flat_targets)

    # A tensor-valued auxiliary loss is averaged down to a scalar.
    if torch.is_tensor(aux_loss):
        aux_loss = aux_loss.mean()

    # No labels on the corpus: report the -1 sentinel instead of an error rate.
    if not hasattr(corpus, "train_labels"):
        return loss, aux_loss, -1

    # Error rate = share of positions whose top-scoring class is wrong.
    predicted = flat_out.max(dim=1)[1]
    err = (flat_targets != predicted).float().mean()
    return loss, aux_loss, err