in src/mlm/scorers.py [0:0]
def bin(self, corpus: Corpus, temp: float = 1.0, split_size: int = 2000, ratio: float = 1, num_workers: int = 10) -> "Tuple[np.ndarray, np.ndarray]":
    """Accumulate per-bin counts and sums over every batch of ``corpus``.

    One pair of ``(max_length, max_length)`` accumulators is allocated on
    each context in ``self._ctxs``. ``self._bin_ops`` is handed those
    accumulators for every batch; after the last batch they are copied to
    host memory and summed across contexts.

    Args:
        corpus: Corpus to process; forwarded to ``self._corpus_to_data``.
        temp: Forwarded unchanged to ``self._bin_ops`` (presumably a softmax
            temperature -- confirm against ``_bin_ops``).
        split_size: Forwarded to ``self._corpus_to_data``.
        ratio: Forwarded to ``self._corpus_to_data``.
        num_workers: Forwarded to ``self._corpus_to_data``.

    Returns:
        A ``(bin_counts, bin_sums)`` tuple of NumPy arrays, each of shape
        ``(256, 256)``, holding the totals over all contexts.
    """
    # NOTE: the return annotation is quoted so it is not evaluated at
    # definition time and does not depend on ``Tuple`` being imported.

    # Get MXNet data objects
    dataset, batch_sampler, dataloader = self._corpus_to_data(corpus, split_size, ratio, num_workers)

    # Get number of tokens
    # NOTE(review): the result is not used in this method; the call is kept
    # in case the helper has side effects -- confirm and remove if it is pure.
    self._true_tok_lens(dataset)

    # Side length of the square bin tables.
    max_length = 256
    bin_shape = (max_length, max_length)

    # Compute bins
    # First axis is sentence length
    bin_counts = np.zeros(bin_shape)
    bin_counts_per_ctx = [mx.nd.zeros(bin_shape, ctx=ctx) for ctx in self._ctxs]
    bin_sums = np.zeros(bin_shape)
    bin_sums_per_ctx = [mx.nd.zeros(bin_shape, ctx=ctx) for ctx in self._ctxs]

    sent_count = 0
    batch_log_interval = 20

    # For now just predicts the first non-cls token
    for batch_num, batch in enumerate(dataloader, start=1):
        batch = self._split_batch(batch)
        # The per-context accumulators are passed in on every batch;
        # the return value is the number of sentences just processed.
        batch_size = self._bin_ops(batch, bin_counts_per_ctx, bin_sums_per_ctx, temp)

        # Progress
        sent_count += batch_size
        if batch_num % batch_log_interval == 0:
            logging.info("{} sents of {}, batch {} of {}".format(sent_count, len(dataset), batch_num, len(batch_sampler)))

    # Accumulate the counts: one device-to-host copy per context, done once
    # after all batches rather than inside the batch loop.
    for ctx_counts, ctx_sums in zip(bin_counts_per_ctx, bin_sums_per_ctx):
        bin_counts += ctx_counts.asnumpy()
        bin_sums += ctx_sums.asnumpy()

    return bin_counts, bin_sums