in trainer.py [0:0]
def compute_total_loss(args, out, Y, corpus, aux_loss):
    """Compute the NLL loss, the reduced auxiliary loss and the error rate.

    When ``args.data_omit_label_idx`` is not None the whole computation is
    delegated to ``compute_masked_loss`` with the same arguments.

    Otherwise ``out`` is flattened to ``(-1, out.size(-1))`` and ``Y`` to a
    1-D tensor before being passed to ``F.nll_loss`` (which expects
    log-probabilities as input and class indices as target).

    Args:
        args: Options object; only ``data_omit_label_idx`` is read here.
        out: Model output whose last dimension indexes the classes.
        Y: Target class indices, with one entry per row of the flattened
            ``out``.
        corpus: Dataset object; only the presence of a ``train_labels``
            attribute is checked.
        aux_loss: Auxiliary loss. Reduced with ``.mean()`` when it is a
            tensor, returned unchanged otherwise.

    Returns:
        A ``(loss, aux_loss, err)`` tuple. ``err`` is the fraction of
        positions whose highest-scoring class differs from ``Y`` when
        ``corpus`` has a ``train_labels`` attribute, and ``-1`` otherwise.
    """
    if args.data_omit_label_idx is not None:
        return compute_masked_loss(args, out, Y, corpus, aux_loss)

    # Merge batch dim and temporal dim. ``reshape`` is used instead of
    # ``view`` so that non-contiguous inputs (e.g. transposed tensors) are
    # accepted; for contiguous inputs it returns a view, exactly as before.
    out = out.reshape(-1, out.size(-1))
    Y = Y.reshape(-1)

    # Compute loss.
    loss = F.nll_loss(out, Y)
    if torch.is_tensor(aux_loss):
        aux_loss = aux_loss.mean()

    if hasattr(corpus, "train_labels"):
        # Compute the error rate: share of positions where the top-scoring
        # class does not match the target.
        _, pred = out.max(dim=1)
        err = Y.ne(pred).float().mean()
    else:
        # Sentinel meaning "error rate not computed".
        err = -1
    return loss, aux_loss, err