def _batch_ops()

in src/mlm/scorers.py

Scores one batch of sentences, one sub-batch per device context: the model is run on each context's token IDs, a temperature-scaled log-softmax is taken over the vocabulary, the log-probabilities of the true next tokens are gathered (skipping the leading and trailing special tokens), the sentence indices and scores are appended to the per-context accumulators, and the number of sentences processed is returned. The snippet relies on the module-level import of mxnet as mx.

    def _batch_ops(self, batch, batch_sent_idxs_per_ctx, batch_scores_per_ctx, temp, per_token=False) -> int:

        batch_size = 0

        for ctx_idx, (sent_idxs, token_ids, valid_length) in enumerate(batch):

            ctx = self._ctxs[ctx_idx]
            batch_size += sent_idxs.shape[0]
            token_ids = token_ids.as_in_context(ctx)
            valid_length = valid_length.as_in_context(ctx)

            # out is ((batch size, max_seq_len, vocab size), new states)
            out = self._model(token_ids)

            # Get the probability computed for the correct token
            split_size = token_ids.shape[0]  # number of sentences in this sub-batch; not used further in this function

            # TODO: Manual numerically-stable softmax (see the sketch after this function)
            # https://stackoverflow.com/questions/42599498/numercially-stable-softmax
            # Because per position we only need one scalar (the true next token's log-probability)
            out = out[0].log_softmax(temperature=temp)

            # Get scores ignoring our version of CLS (the 'G.') and SEP (the '<|endoftext|>' after the terminal 'G.')
            # Recall that the softmaxes here are for predicting the >>next<< token
            batch_sent_idxs_per_ctx[ctx_idx].append(sent_idxs)

            if per_token:
                # Each entry will be a list of scores
                out_final = [None]*out.shape[0]
            else:
                out_final = mx.nd.zeros((out.shape[0],), ctx=ctx)
            for i in range(out.shape[0]):
                # Get scores ignoring our version of CLS and SEP ('<|endoftext|>')
                # Recall that the softmaxes here are for predicting the >>next<< token
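                # Position j's log-softmax row scores token j + 1, so rows 0 .. seq_len - 3
                # line up with target tokens 1 .. seq_len - 2 (the first and the final token
                # are excluded as targets)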
                seq_len = int(valid_length[i].asscalar())
                out_final_temp = out[i, list(range(seq_len - 2)), token_ids[i, 1:(seq_len - 1)]]
                if per_token:
                    out_final[i] = out_final_temp.asnumpy().tolist()
                else:
                    out_final[i] = out_final_temp.sum()
            batch_scores_per_ctx[ctx_idx].append(out_final)

        return batch_size
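
The TODO in the body above points at the standard trick for a numerically stable (log-)softmax: shift the logits by their maximum before exponentiating, which is what the linked Stack Overflow question covers. A minimal NumPy sketch of that trick; the function name and the use of NumPy here are illustrative assumptions, not code from this repository:

    import numpy as np

    def log_softmax_stable(x, temperature=1.0, axis=-1):
        # Subtracting the row max leaves the log-softmax unchanged (it is invariant
        # to adding a constant) but keeps exp() from overflowing.
        z = np.asarray(x, dtype=np.float64) / temperature
        z = z - z.max(axis=axis, keepdims=True)
        return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))

    # e.g. log_softmax_stable([1000.0, 0.0]) -> array([0., -1000.]) with no overflow

Since only the true next token's score is needed per position (the "one scalar" in the comment), the same idea could be specialized to x[target] - logsumexp(x) without materializing the whole row; the sketch above just shows the general form.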