# src/mlm/scorers.py
import mxnet as mx


def _batch_ops(self, batch, batch_sent_idxs_per_ctx, batch_scores_per_ctx, temp, per_token=False) -> int:
    batch_size = 0
    for ctx_idx, (sent_idxs, token_ids, valid_length) in enumerate(batch):
        ctx = self._ctxs[ctx_idx]
        batch_size += sent_idxs.shape[0]
        token_ids = token_ids.as_in_context(ctx)
        valid_length = valid_length.as_in_context(ctx)
        # out is ((batch size, max_seq_len, vocab size), new states)
        out = self._model(token_ids)
        # Get the log-probability computed for the correct (next) token
        split_size = token_ids.shape[0]  # currently unused
        # TODO: Manual numerically-stable softmax, since we only need one scalar per position
        # https://stackoverflow.com/questions/42599498/numercially-stable-softmax
        out = out[0].log_softmax(temperature=temp)
        batch_sent_idxs_per_ctx[ctx_idx].append(sent_idxs)
        if per_token:
            # Each entry will be a list of per-token scores
            out_final = [None] * out.shape[0]
        else:
            out_final = mx.nd.zeros((out.shape[0],), ctx=ctx)
        for i in range(out.shape[0]):
            # Get scores ignoring our version of CLS (the 'G.') and SEP (the
            # '<|endoftext|>' after the terminal 'G.'). Recall that the softmax
            # at position t predicts the >>next<< token, t + 1.
            vlen = int(valid_length[i].asscalar())  # cast: asscalar() may return a NumPy float
            out_final_temp = out[i, list(range(vlen - 2)), token_ids[i, 1:(vlen - 1)]]
            if per_token:
                out_final[i] = out_final_temp.asnumpy().tolist()
            else:
                out_final[i] = out_final_temp.sum()
        batch_scores_per_ctx[ctx_idx].append(out_final)
    return batch_size
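

# --- Reference sketch (not from this repository) ----------------------------
# A minimal, hedged example of the "manual numerically-stable softmax" the
# TODO above refers to: when only one scalar per position is needed, the
# log-sum-exp trick yields log P(target) directly, without materializing the
# full vocabulary softmax. `log_softmax_target` and its arguments are
# illustrative names, not part of this package's API.
import numpy as np


def log_softmax_target(logits: np.ndarray, target: int, temp: float = 1.0) -> float:
    """Return log P(target) under softmax(logits / temp), computed stably."""
    scaled = logits / temp
    shifted = scaled - scaled.max()         # max-shift prevents exp() overflow
    log_z = np.log(np.exp(shifted).sum())   # log of the partition function
    return float(shifted[target] - log_z)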
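

# --- Illustrative walkthrough of the next-token alignment -------------------
# Hypothetical values (not from this repository), reusing the numpy import
# above. The gather in the inner loop reads positions 0..vlen-3 of `out[i]`
# against targets token_ids[i, 1:vlen-1], because the softmax at position t
# predicts token t + 1: the CLS-like token at index 0 and the SEP at index
# vlen - 1 are never targets.
log_probs_ex = np.random.randn(5, 10)     # stand-in for out[i]: (seq_len, vocab)
token_ids_ex = np.array([7, 2, 4, 4, 9])  # [CLS-like, w1, w2, w3, SEP]
vlen_ex = 5
scores_ex = log_probs_ex[np.arange(vlen_ex - 2), token_ids_ex[1:vlen_ex - 1]]
sentence_score_ex = scores_ex.sum()       # summed log-probs of w1..w3 only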