def _train_batch()

in trainer.py [0:0]


import torch


def _train_batch(model, optimizer, scheduler, X, Y, h_cache,
                 eval_only, batch_split):
    """Train on a batch."""

    optimizer.zero_grad()

    if batch_split == 1:
        # process a batch in a single step (default behaviour)
        loss_value, h_cache = _train_step(model, X, Y, h_cache, eval_only)
    else:
        # split a batch into multiple pieces that each can fit in memory
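        # (zero_grad() is only called once per batch above, so each split's
        # backward pass inside _train_step accumulates gradients for the
        # whole batch)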
        assert X.size(0) % batch_split == 0
        split_size = X.size(0) // batch_split
        loss_value = 0
        h_cache_list = []
        for split_ind in range(batch_split):
            split_slice = slice(split_ind*split_size, (split_ind+1)*split_size)
            split_h_cache = [h[split_slice,:,:] for h in h_cache]
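            # the extra batch_split argument lets _train_step scale down each
            # split's loss, so the accumulated gradient matches a single
            # full-batch step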
            split_loss_value, split_h_cache = _train_step(
                model, X[split_slice,:], Y[split_slice],
                split_h_cache, eval_only, batch_split)
            loss_value += split_loss_value
            h_cache_list.append(split_h_cache)
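        # stitch the per-split hidden-state caches back together along the
        # batch dimension, one tensor per cache entry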
        h_cache = [
            torch.cat([h_cache_list[i][l] for i in range(batch_split)], dim=0)
            for l in range(len(h_cache))]

    if not eval_only:
        if scheduler is not None:
            scheduler.step()
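        # grad_clip is assumed to be an attribute attached to the optimizer
        # elsewhere in the codebase (it is not a standard torch.optim field)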
        if optimizer.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), optimizer.grad_clip)
        optimizer.step()

        # make sure the adaptive-span parameters stay within their valid range
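        # (the model is assumed to be wrapped, e.g. in nn.DataParallel,
        # hence the .module indirection)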
        if model.module.layers[0].attn.attn.adapt_span_enabled:
            for layer in model.module.layers:
                layer.attn.attn.adaptive_span.clamp_param()

    return loss_value, h_cache
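
For context, here is a minimal sketch of how _train_batch might be driven by an
outer training loop. Everything besides _train_batch itself (nb_batches,
block_size, train_data, init_hidden_cache, and the way token blocks are sliced
out of the data) is an illustrative assumption, not code taken from trainer.py.


# hypothetical outer loop; all names other than _train_batch are assumptions
hid_cache = init_hidden_cache()  # assumed helper: one batch-first tensor per cache entry
total_loss = 0.0
for batch_ind in range(nb_batches):
    # consecutive token blocks; Y is X shifted by one position
    offset = batch_ind * block_size
    X = train_data[:, offset:offset + block_size].contiguous()
    Y = train_data[:, offset + 1:offset + block_size + 1].contiguous()
    loss, hid_cache = _train_batch(
        model, optimizer, scheduler, X, Y, hid_cache,
        eval_only=False, batch_split=2)  # split each batch in two to fit in memory
    total_loss += loss
avg_loss = total_loss / nb_batches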