in trainer.py [0:0]
import torch
import torch.nn.functional as F


def compute_masked_loss(args, out, Y, corpus, aux_loss):
    # merge batch dim and temporal dim
    out = out.view(-1, out.size(-1))
    Y = Y.view(-1)
    # do not train on specified output tokens
    mask = torch.zeros_like(Y, dtype=torch.bool)
    for w in args.data_omit_label_idx:
        mask |= Y.eq(w)
    mask = (~mask).float()  # 1.0 for tokens that contribute to the loss
    # compute the NLL loss, averaged over unmasked tokens only
    loss = F.nll_loss(out, Y, reduction="none")
    loss = loss * mask
    loss = loss.sum() / (mask.sum() + 1e-6)
    if torch.is_tensor(aux_loss):
        if args.expire_span:
            # this loss has no correspondence to input tokens
            aux_loss = aux_loss.mean()
        else:
            aux_loss = aux_loss.view(-1)
            aux_loss = aux_loss * mask
            aux_loss = aux_loss.sum() / (mask.sum() + 1e-6)
    if hasattr(corpus, "train_labels"):
        # compute the error rate (1 - accuracy) over unmasked tokens
        _, pred = out.max(dim=1)
        err = Y.ne(pred).float()
        err = err * mask
        err = err.sum() / (mask.sum() + 1e-6)
    else:
        err = -1
    return loss, aux_loss, err
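

# Hedged usage sketch (not part of the original file): shows how compute_masked_loss
# could be exercised in isolation. The SimpleNamespace stand-ins for `args` and
# `corpus`, the omitted label index 0, and the tensor sizes are illustrative
# assumptions, not values used by the actual training code.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(data_omit_label_idx=[0], expire_span=False)
    corpus = SimpleNamespace(train_labels=True)  # attribute presence enables the error computation
    B, T, V = 2, 8, 16  # batch size, sequence length, vocabulary size
    out = F.log_softmax(torch.randn(B, T, V), dim=-1)  # nll_loss expects log-probabilities
    Y = torch.randint(0, V, (B, T))
    loss, aux_loss, err = compute_masked_loss(args, out, Y, corpus, aux_loss=0.0)
    print(f"loss={loss.item():.4f} err={err.item():.4f}")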