def _train_step()

in trainer.py [0:0]


def _train_step(model, X, Y, h_cache, eval_only, loss_div=1):
    """Run one forward pass and, unless ``eval_only``, one backward pass.

    Gradients are accumulated into the parameters' ``.grad`` by
    ``backward()``. This function neither zeroes gradients nor steps an
    optimizer; the caller is responsible for both.

    Args:
        model: Callable invoked as ``model(X, h_cache, target=Y)`` that must
            return ``(out, h_cache, dummy_loss)``. It may be a wrapper that
            exposes the underlying network as ``model.module`` (presumably
            ``DataParallel``/``DistributedDataParallel``) or the network
            itself. The ``adapt_io`` and ``layers`` attributes are read from
            ``model.module`` when present, otherwise from ``model``.
        X: Input batch, passed straight to ``model``.
        Y: Targets. Passed to ``model`` and, in the non-``adapt_io`` path,
            flattened with ``Y.view(-1)`` as NLL targets.
        h_cache: Cache passed to ``model``. The cache that ``model`` returns
            is handed back to the caller.
        eval_only: If true, skip the adaptive-span penalty and
            ``backward()``. Autograd is NOT disabled here; wrap the call in
            ``torch.no_grad()`` if no graph should be built.
        loss_div: Divisor applied to both the reported loss and the loss
            that is back-propagated. Presumably this is the number of chunks
            a batch is split into for gradient accumulation; confirm with
            the caller.

    Returns:
        ``(loss_value, h_cache)``. ``loss_value`` is a Python float: the
        task loss divided by ``loss_div``. It excludes the adaptive-span
        penalty.
    """
    out, h_cache, dummy_loss = model(X, h_cache, target=Y)

    # Parallel wrappers keep the real network (and its config) on
    # ``.module``; fall back to ``model`` so unwrapped models also work.
    net = getattr(model, "module", model)

    if net.adapt_io:
        # NOTE(review): in this mode ``out`` appears to already hold per-token
        # losses, so it is averaged directly. ``dummy_loss`` is summed in,
        # presumably so that every parameter takes part in backward;
        # confirm against the model.
        loss = out.mean() + dummy_loss.sum()
    else:
        # Flatten to (N, num_classes) rows. nll_loss expects
        # log-probabilities as its input.
        out = out.view(-1, out.size(-1))
        loss = torch.nn.functional.nll_loss(out, Y.view(-1))

    # Reported value is taken before the regularizer is added below.
    loss_value = loss.item() / loss_div

    if not eval_only:
        # Adaptive-span penalty. The first layer's flag decides whether
        # adaptive span is active for the whole model.
        if net.layers[0].attn.attn.adapt_span_enabled:
            span_penalty = sum(
                layer.attn.attn.adaptive_span.get_loss() for layer in net.layers
            )
            # Out-of-place add: avoids mutating a tensor in the autograd graph.
            loss = loss + span_penalty

        (loss / loss_div).backward()

    return loss_value, h_cache