in trainer.py [0:0]
import torch  # needed by this excerpt for torch.cat and gradient clipping


def _train_batch(model, optimizer, scheduler, X, Y, h_cache,
                 eval_only, batch_split):
    """Train (or evaluate) on one batch, optionally splitting it into
    smaller pieces so that each piece fits in memory."""
    optimizer.zero_grad()

    if batch_split == 1:
        # process a batch in a single step (default behaviour)
        loss_value, h_cache = _train_step(model, X, Y, h_cache, eval_only)
    else:
        # split a batch into multiple pieces that each can fit in memory
        assert X.size(0) % batch_split == 0
        split_size = X.size(0) // batch_split
        loss_value = 0
        h_cache_list = []
        for split_ind in range(batch_split):
            split_slice = slice(split_ind * split_size,
                                (split_ind + 1) * split_size)
            # slice the per-layer hidden-state caches along the batch dimension
            split_h_cache = [h[split_slice, :, :] for h in h_cache]
            # batch_split is also passed so _train_step can rescale the loss,
            # letting the gradients accumulated over the splits match a
            # full-batch update
            split_loss_value, split_h_cache = _train_step(
                model, X[split_slice, :], Y[split_slice],
                split_h_cache, eval_only, batch_split)
            loss_value += split_loss_value
            h_cache_list.append(split_h_cache)
        # reassemble the full-batch caches by concatenating each layer's
        # per-split caches along the batch dimension
        h_cache = [
            torch.cat([h_cache_list[i][l] for i in range(batch_split)], dim=0)
            for l in range(len(h_cache))]

    if not eval_only:
        if scheduler is not None:
            scheduler.step()
        # clip gradients before the update if a clipping threshold is set
        if optimizer.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), optimizer.grad_clip)
        optimizer.step()

        # make sure span parameters are in a correct range
        # (model is wrapped, e.g. in DataParallel, hence model.module)
        if model.module.layers[0].attn.attn.adapt_span_enabled:
            for layer in model.module.layers:
                layer.attn.attn.adaptive_span.clamp_param()

    return loss_value, h_cache
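
A minimal sketch of how a caller might drive _train_batch. It is illustrative only: the names data, block_size, nb_batches, train_pos, hid_cache, and args.batch_split are assumptions, not part of the excerpt above. The loop walks over contiguous token blocks and feeds the hidden-state cache returned by one batch into the next.

# Hypothetical driver loop (not part of trainer.py): `data` is assumed to be a
# LongTensor of token ids with shape [batch, total_tokens], and `hid_cache` a
# list of per-layer cache tensors whose first dimension is the batch size.
train_pos = 0
total_loss = 0.0
for _ in range(nb_batches):
    X = data[:, train_pos: train_pos + block_size].contiguous()
    Y = data[:, train_pos + 1: train_pos + block_size + 1].contiguous()
    loss, hid_cache = _train_batch(
        model, optimizer, scheduler, X, Y, hid_cache,
        eval_only=False, batch_split=args.batch_split)
    total_loss += loss
    train_pos += block_size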