# src/mlm/scorers.py
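# Imports this method relies on, reconstructed from usage below; the actual
# import block at the top of scorers.py may differ (the Corpus path in
# particular is an assumption).
import logging
from typing import List, Tuple, Union

import mxnet as mx
import numpy as np

from mlm.loaders import Corpus  # assumed import path for the Corpus type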
def score(self, corpus: Corpus, temp: float = 1.0, split_size: int = 2000, ratio: float = 1.0, num_workers: int = 10, per_token: bool = False) -> Tuple[Union[List[float], List[List[float]]], List[int]]:
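    """Score each sentence in `corpus` under the model.

    `temp` is the softmax temperature applied when scoring; `split_size`,
    `ratio`, and `num_workers` are forwarded to the data pipeline (see
    `_corpus_to_data`). When `per_token` is True, each sentence receives a
    list of per-position scores rather than a single summed score.

    Returns `(scores, true_tok_lens)`, where `true_tok_lens` holds the
    number of true (non-special) tokens in each sentence.
    """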
    ctx_cpu = mx.Context('cpu')

    # Get MXNet data objects
    dataset, batch_sampler, dataloader = self._corpus_to_data(corpus, split_size, ratio, num_workers)

    # Get the number of true tokens in each sentence
    true_tok_lens = self._true_tok_lens(dataset)

    # Compute scores (total per sentence, or per-position if requested)
    if per_token:
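        # +2 slots per sentence: the special tokens ([CLS]/[SEP] in
        # BERT-style models) bracket the true tokens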
        scores_per_token = [[None]*(true_tok_len+2) for true_tok_len in true_tok_lens]
    else:
        scores = np.zeros((len(corpus),))
    sent_count = 0
    batch_log_interval = 20
    batch_score_accumulation = 1
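    # Per-context staging buffers: each context appends (sentence indices,
    # scores) pairs here, and the buffers are flushed synchronously below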
    batch_sent_idxs_per_ctx = [[] for _ in self._ctxs]
    batch_scores_per_ctx = [[] for _ in self._ctxs]
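    # Flush the staged (indices, scores) pairs from every context into the
    # host-side accumulators, then clear the buffers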
    def sum_accumulated_scores():
        for ctx_idx in range(len(self._ctxs)):
            for batch_sent_idxs, batch_scores in zip(batch_sent_idxs_per_ctx[ctx_idx], batch_scores_per_ctx[ctx_idx]):
                if per_token:
                    # Slow; only use when necessary
                    for batch_sent_idx, batch_score in zip(batch_sent_idxs, batch_scores):
                        scores_per_token[int(batch_sent_idx.asscalar())] = batch_score
                else:
                    np.add.at(scores, batch_sent_idxs.asnumpy(), batch_scores.asnumpy())
            batch_sent_idxs_per_ctx[ctx_idx] = []
            batch_scores_per_ctx[ctx_idx] = []
    # Farm each batch out across the contexts and stage the resulting scores
    for batch_id, batch in enumerate(dataloader):
        batch = self._split_batch(batch)
        batch_size = self._batch_ops(batch, batch_sent_idxs_per_ctx, batch_scores_per_ctx, temp, per_token=per_token)
        # Ideally we'd accumulate the scores asynchronously, but fancy-indexed
        # in-place addition like
        # > scores[sent_idxs] += out
        # applies each repeated index only once; see In[21] of
        # https://jakevdp.github.io/PythonDataScienceHandbook/02.07-fancy-indexing.html.
        # Hence, aggregation is done synchronously every `batch_score_accumulation`
        # batches (1 seems best, since bucketing is effective in reducing GPU disparity).
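        # For illustration (plain NumPy, outside this codebase):
        #   x = np.zeros(3); x[[0, 0]] += 1.0    # x[0] == 1.0: the duplicate is dropped
        #   np.add.at(x, [0, 0], 1.0)            # x[0] == 2.0: duplicates accumulate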
        if len(batch_sent_idxs_per_ctx[0]) == batch_score_accumulation:
            sum_accumulated_scores()
        # Progress
        sent_count += batch_size
        if (batch_id+1) % batch_log_interval == 0:
            logging.info("{} sents of {}, batch {} of {}".format(sent_count, len(dataset), batch_id+1, len(batch_sampler)))

    # In case there are leftovers
    sum_accumulated_scores()
    if per_token:
        return scores_per_token, true_tok_lens
    else:
        return scores.tolist(), true_tok_lens
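
# Example usage (hypothetical names; assumes a concrete scorer from this
# package built around a pretrained MLM, plus a loaded Corpus):
#
#   scores, true_tok_lens = scorer.score(corpus)
#   # Corpus-level pseudo-perplexity: exponentiated average negative score
#   # per true token (each score sums per-token log-probabilities)
#   pppl = math.exp(-sum(scores) / sum(true_tok_lens))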