in trainer.py [0:0]
def _train_step(model, X, Y, h_cache, eval_only, loss_div=1):
    """Run the model on one batch and, unless evaluating, backpropagate.

    Args:
        model: wrapped model. It is called as ``model(X, h_cache, target=Y)``
            and must return ``(output, new_cache, dummy_loss)``. Its
            underlying module is reached through ``model.module``.
        X: input batch, forwarded unchanged to the model.
        Y: targets. They are passed to the model as ``target`` and, when the
            module does not use adaptive I/O, flattened for ``nll_loss``.
        h_cache: cache handed to the model call.
        eval_only: when true, neither the adaptive-span penalty nor the
            backward pass is run.
        loss_div: divisor applied both to the reported loss and to the loss
            that is backpropagated.

    Returns:
        A tuple ``(loss_value, cache)``. ``loss_value`` is the Python float
        loss divided by ``loss_div``; it never includes the adaptive-span
        penalty. ``cache`` is the cache returned by the model.
    """
    output, new_cache, dummy_loss = model(X, h_cache, target=Y)
    core = model.module

    if not core.adapt_io:
        # Flatten to (N, C) / (N,) for NLL loss; presumably the output holds
        # log-probabilities -- confirm against the model's forward().
        flat_scores = output.view(-1, output.size(-1))
        loss = torch.nn.functional.nll_loss(flat_scores, Y.view(-1))
    else:
        # With adaptive I/O the model output is averaged directly. The summed
        # dummy term is added in; presumably this keeps it in the autograd
        # graph -- TODO confirm.
        loss = output.mean() + dummy_loss.sum()

    # The number reported to the caller is taken before any span penalty.
    loss_value = loss.item() / loss_div

    if eval_only:
        return loss_value, new_cache

    # The adaptive-span flag is read from the first layer only, and then
    # every layer contributes its own span loss.
    if core.layers[0].attn.attn.adapt_span_enabled:
        loss += sum(
            layer_module.attn.attn.adaptive_span.get_loss()
            for layer_module in core.layers
        )
    (loss / loss_div).backward()
    return loss_value, new_cache