in src/mlm/scorers.py [0:0]
def _batch_ops(self, batch, batch_sent_idxs_per_ctx, batch_scores_per_ctx, temp) -> Tuple[int, mx.nd.NDArray]:
    batch_size = 0
    losses = []
    # Forward pass: each element of `batch` is one shard of the minibatch,
    # placed on the corresponding device context in self._ctxs
    with mx.autograd.record():
        for ctx_idx, (sent_idxs, token_ids, valid_length, scores) in enumerate(batch):
            ctx = self._ctxs[ctx_idx]
            batch_size += sent_idxs.shape[0]
            token_ids = token_ids.as_in_context(ctx)
            valid_length = valid_length.as_in_context(ctx)
            scores = scores.as_in_context(ctx)
            segment_ids = mx.nd.zeros(shape=token_ids.shape, ctx=ctx)
            out = self._model(token_ids, segment_ids, valid_length)
            loss = self._loss(out, scores).sum()
            losses.append(loss)
    # Backward pass on each device's loss
    for loss in losses:
        loss.backward()
    # Synchronous: copying the per-device losses to the CPU blocks until the
    # forward/backward computations have finished
    batch_loss = sum([loss.as_in_context(mx.cpu()) for loss in losses])
    losses = []
    # Gradient clipping: aggregate gradients across devices, then clip the
    # global norm to 1
    self._trainer.allreduce_grads()
    nlp.utils.clip_grad_global_norm(self._params, 1)
    # TODO: What is correct # of steps?
    # TODO: Stale grad?
    mx.nd.waitall()
    # Scale the update by the total number of sentences in this batch
    self._trainer.update(batch_size, ignore_stale_grad=True)
    return batch_size, batch_loss
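
A minimal usage sketch of the method above, not part of the original file: it builds one dummy (sent_idxs, token_ids, valid_length, scores) shard per device and runs a single optimization step. `scorer` stands in for an instance of the class that owns `_batch_ops`; the shapes and placeholder arrays below are assumptions for illustration only.

import mxnet as mx

seq_len, shard_size = 32, 4
batch = []
for _ in scorer._ctxs:
    # One shard per device; _batch_ops moves each array to its context itself
    sent_idxs = mx.nd.arange(shard_size)
    token_ids = mx.nd.zeros((shard_size, seq_len))
    valid_length = mx.nd.full((shard_size,), seq_len)
    scores = mx.nd.zeros((shard_size,))
    batch.append((sent_idxs, token_ids, valid_length, scores))

n_sents, batch_loss = scorer._batch_ops(batch, None, None, temp=1.0)
print(n_sents, batch_loss.asscalar())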