def _batch_ops()

in src/mlm/scorers.py [0:0]
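
One fine-tuning step over a multi-device batch: each element of `batch` is forwarded through `self._model` on its own context, the per-context losses are backpropagated, gradients are all-reduced and clipped to a global norm of 1, and `self._trainer` applies a single update. The method returns the number of sentences processed and the summed loss.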


    def _batch_ops(self, batch, batch_sent_idxs_per_ctx, batch_scores_per_ctx, temp) -> Tuple[int, mx.nd.NDArray]:

        batch_size = 0
        losses = []

        with mx.autograd.record():
            # One forward pass per device context; each element of `batch` holds
            # the shard of sentence indices, token ids, valid lengths, and target
            # scores assigned to that context.
            for ctx_idx, (sent_idxs, token_ids, valid_length, scores) in enumerate(batch):
                ctx = self._ctxs[ctx_idx]
                batch_size += sent_idxs.shape[0]
                token_ids = token_ids.as_in_context(ctx)
                valid_length = valid_length.as_in_context(ctx)
                scores = scores.as_in_context(ctx)
                # Single-sequence input, so all segment ids are zero
                segment_ids = mx.nd.zeros(shape=token_ids.shape, ctx=ctx)
                out = self._model(token_ids, segment_ids, valid_length)
                loss = self._loss(out, scores).sum()
                losses.append(loss)
            
            # Backpropagate each per-context loss
            for loss in losses:
                loss.backward()

        # Synchronous: gather the per-context losses onto the CPU and sum them into one value
        batch_loss = sum([loss.as_in_context(mx.cpu()) for loss in losses])
        losses = []

        # Reduce gradients across contexts, then clip them to a global norm of 1
        self._trainer.allreduce_grads()
        nlp.utils.clip_grad_global_norm(self._params, 1)
        # TODO: What is correct # of steps?
        # TODO: Stale grad?
        # Wait for pending operations, then apply one optimizer step, normalizing gradients by batch_size
        mx.nd.waitall()
        self._trainer.update(batch_size, ignore_stale_grad=True)

        return batch_size, batch_loss
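
Below is a minimal sketch of how a training loop might drive `_batch_ops`. It is an assumption-laden illustration, not code from this repository: the `scorer` instance, the `train_dataloader`, and the use of `gluon.utils.split_and_load` for sharding are hypothetical; the per-context tuple layout `(sent_idxs, token_ids, valid_length, scores)` is inferred from the loop above; `None` is passed for the `*_per_ctx` arguments and `temp=1.0` only because the excerpt above does not use them.

    # Hypothetical driver loop for _batch_ops (assumed names: scorer, train_dataloader)
    import mxnet as mx
    from mxnet import gluon

    # One context per available GPU, falling back to CPU; must match scorer._ctxs
    ctxs = [mx.gpu(i) for i in range(mx.context.num_gpus())] or [mx.cpu()]

    def shard(arr):
        # Split a batch-level NDArray into one piece per context
        return gluon.utils.split_and_load(arr, ctx_list=ctxs, even_split=False)

    for sent_idxs, token_ids, valid_length, scores in train_dataloader:
        # Re-zip the shards so each element of `batch` is the per-context tuple
        # that _batch_ops unpacks: (sent_idxs, token_ids, valid_length, scores)
        batch = list(zip(shard(sent_idxs), shard(token_ids),
                         shard(valid_length), shard(scores)))
        batch_size, batch_loss = scorer._batch_ops(batch, None, None, temp=1.0)
        print('sentences: {}, loss: {}'.format(batch_size, batch_loss.asscalar()))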