def compute_masked_loss()

in trainer.py [0:0]


def compute_masked_loss(args, out, Y, corpus, aux_loss):
    """Compute the masked NLL loss, masked auxiliary loss and error rate.

    Positions whose target equals any value in ``args.data_omit_label_idx``
    get weight 0; all other positions get weight 1. Each returned average is
    the weighted sum divided by ``(number of kept positions + 1e-6)``, so an
    all-masked batch yields 0 rather than a division by zero.

    Args:
        args: Config object. ``data_omit_label_idx`` (iterable of label ids
            to exclude; may be empty) is always read; ``expire_span`` is read
            only when ``aux_loss`` is a tensor.
        out: Per-position class scores with classes on the last dimension.
            They are passed straight to ``F.nll_loss``, which expects
            log-probabilities.
        Y: Target label ids, one per position of ``out``.
        corpus: Only inspected for the presence of a ``train_labels``
            attribute, which enables the error-rate computation.
        aux_loss: Auxiliary loss. A non-tensor value is returned unchanged.
            A tensor is averaged over all of its elements when
            ``args.expire_span`` is truthy; otherwise it is flattened and
            averaged over the kept positions, so it is expected to hold one
            element per target position.

    Returns:
        Tuple ``(loss, aux_loss, err)``. ``err`` is the masked fraction of
        positions where the arg-max prediction differs from ``Y``, or the
        int ``-1`` when ``corpus`` has no ``train_labels`` attribute.
    """
    # Merge batch dim and temporal dim. reshape() behaves like view() for
    # contiguous tensors and additionally accepts non-contiguous ones.
    out = out.reshape(-1, out.size(-1))
    Y = Y.reshape(-1)

    # Do not train on specified output tokens. Start from an all-False
    # tensor (not the Python scalar False) so an empty omit list still yields
    # a tensor mask, and combine with logical OR so a repeated index cannot
    # push a position's weight below zero.
    omitted = torch.zeros_like(Y, dtype=torch.bool)
    for w in args.data_omit_label_idx:
        omitted = omitted | Y.eq(w)
    mask = (~omitted).float()

    # Shared normaliser: number of kept positions, padded to avoid 0/0.
    denom = mask.sum() + 1e-6

    # compute loss
    loss = F.nll_loss(out, Y, reduction="none")
    loss = (loss * mask).sum() / denom

    if torch.is_tensor(aux_loss):
        if args.expire_span:
            # this loss has no correspondence to input tokens
            aux_loss = aux_loss.mean()
        else:
            aux_loss = (aux_loss.reshape(-1) * mask).sum() / denom

    if hasattr(corpus, "train_labels"):
        # compute acc
        _, pred = out.max(dim=1)
        err = (Y.ne(pred).float() * mask).sum() / denom
    else:
        err = -1
    return loss, aux_loss, err