def bin()

in src/mlm/scorers.py [0:0]


    def bin(self, corpus: Corpus, temp: float = 1.0, split_size: int = 2000, ratio: float = 1, num_workers: int = 10) -> tuple:
        """Accumulate 2-D bin counts and bin sums over ``corpus`` on every context.

        The corpus is converted to MXNet data objects. Each batch is split with
        ``self._split_batch`` and handed to ``self._bin_ops``, which updates one
        pair of (256, 256) accumulators per context in ``self._ctxs``. At the end
        the per-context accumulators are copied to host memory and summed.

        Args:
            corpus: Input corpus; forwarded to ``self._corpus_to_data``.
            temp: Forwarded unchanged to ``self._bin_ops`` (presumably a
                softmax temperature -- confirm in ``_bin_ops``).
            split_size: Forwarded to ``self._corpus_to_data``.
            ratio: Forwarded to ``self._corpus_to_data``.
            num_workers: Forwarded to ``self._corpus_to_data``.

        Returns:
            ``(bin_counts, bin_sums)``: two NumPy arrays of shape (256, 256),
            each the element-wise sum over all contexts of the corresponding
            per-context accumulator. The first axis is sentence length (per
            the original author's comment).
        """
        # Get MXNet data objects
        dataset, batch_sampler, dataloader = self._corpus_to_data(corpus, split_size, ratio, num_workers)

        # Accumulators are fixed-size; whatever _bin_ops writes must fit in
        # (max_length, max_length).
        max_length = 256
        # First axis is sentence length (second axis presumably token
        # position -- TODO confirm against _bin_ops).
        bin_counts = np.zeros((max_length, max_length))
        bin_sums = np.zeros((max_length, max_length))
        # One accumulator pair allocated on each context; merged on the host
        # after all batches have been processed.
        bin_counts_per_ctx = [mx.nd.zeros((max_length, max_length), ctx=ctx) for ctx in self._ctxs]
        bin_sums_per_ctx = [mx.nd.zeros((max_length, max_length), ctx=ctx) for ctx in self._ctxs]

        sent_count = 0
        batch_log_interval = 20

        # NOTE(original author): for now just predicts the first non-cls
        # token; the per-batch work is done inside self._bin_ops.
        for batch_id, batch in enumerate(dataloader):
            batch = self._split_batch(batch)

            # The per-context accumulators are never reassigned here, so
            # _bin_ops must update them in place. Its return value is the
            # batch size, which is logged as a sentence count.
            batch_size = self._bin_ops(batch, bin_counts_per_ctx, bin_sums_per_ctx, temp)

            # Progress
            sent_count += batch_size
            if (batch_id+1) % batch_log_interval == 0:
                logging.info("{} sents of {}, batch {} of {}".format(sent_count, len(dataset), batch_id+1, len(batch_sampler)))

        # Copy each context's accumulators to host memory and add them up.
        for counts_nd, sums_nd in zip(bin_counts_per_ctx, bin_sums_per_ctx):
            bin_counts += counts_nd.asnumpy()
            bin_sums += sums_nd.asnumpy()

        return bin_counts, bin_sums