def bin()

in src/mlm/scorers.py [0:0]


    def bin(self, corpus: Corpus, temp: float = 1.0, split_size: int = 2000, ratio: float = 0, num_workers: int = 10) -> Tuple[np.ndarray, np.ndarray]:

        ctx_cpu = mx.Context('cpu')

        # Turn corpus into a BERT-ready Dataset
        dataset = self.corpus_to_dataset(corpus)

        # Turn Dataset into DataLoader; the tuple fields are
        # (sent_idx, token_ids, valid_length, masked_positions, token_masked_ids, normalization)
        batchify_fn = btf.Tuple(btf.Stack(dtype='int32'), btf.Pad(pad_val=self._vocab.token_to_idx[self._vocab.padding_token], dtype='int32'),
                                btf.Stack(dtype='float32'), btf.Stack(dtype='int32'),
                                btf.Stack(dtype='int32'), btf.Stack(dtype='float32'))

        # TODO: There is a 'by-design' bug in FixedBucketSampler with num_shards > 0, where it silently reuses the last utterances:
        # https://github.com/dmlc/gluon-nlp/blame/b1b61d3f90cf795c7b48b6d109db7b7b96fa21ff/src/gluonnlp/data/sampler.py#L398
        # batch_sampler = nlp.data.sampler.FixedBucketSampler([sent_tuple[2] for sent_tuple in dataset], batch_size=split_size, ratio=ratio, num_shards=len(self._ctxs), shuffle=False)
        # Hence, we use num_shards = 0 and do gluon's split_data
        batch_sampler = nlp.data.sampler.FixedBucketSampler([sent_tuple[2] for sent_tuple in dataset], batch_size=split_size, ratio=ratio, num_shards=0, shuffle=False)

        logging.info(batch_sampler.stats())
        dataloader = nlp.data.ShardedDataLoader(dataset, pin_memory=True, batch_sampler=batch_sampler, batchify_fn=batchify_fn, num_workers=num_workers, thread_pool=True)

        max_length = 256
        # Compute bins indexed by (true sentence length, masked token position):
        # bin_counts[l, p] counts masked tokens, bin_sums[l, p] accumulates their log-probabilities.
        # Each context (device) gets its own accumulators; they are merged on the CPU at the end.
        bin_counts = np.zeros((max_length, max_length))
        bin_counts_per_ctx = [mx.nd.zeros((max_length, max_length), ctx=ctx) for ctx in self._ctxs]
        bin_sums = np.zeros((max_length, max_length))
        bin_sums_per_ctx = [mx.nd.zeros((max_length, max_length), ctx=ctx) for ctx in self._ctxs]

        # Collect each sentence's true token length, i.e. valid_length minus [CLS] and [SEP]
        # (assumes the dataset is ordered by sentence index)
        prev_sent_idx = None
        true_tok_lens = []
        for (curr_sent_idx, _, valid_length, _, _, _) in dataset:
            if curr_sent_idx != prev_sent_idx:
                prev_sent_idx = curr_sent_idx
                true_tok_lens.append(valid_length - 2)

        sent_count = 0
        batch_log_interval = 20

        # For now just predicts the first non-cls token
        for batch_id, batch in enumerate(dataloader):

            batch_size = 0

            # TODO: Write tests about batching over multiple GPUs and getting the same scores
            # TODO: SEE COMMENT ABOVE REGARDING FIXEDBUCKETSAMPLER
            batch = zip(*[mx.gluon.utils.split_data(batch_compo, len(self._ctxs), batch_axis=0, even_split=False) for batch_compo in batch])

            for ctx_idx, (sent_idxs, token_ids, valid_length, masked_positions, token_masked_ids, normalization) in enumerate(batch):

                ctx = self._ctxs[ctx_idx]
                batch_size += sent_idxs.shape[0]
                token_ids = token_ids.as_in_context(ctx)
                valid_length = valid_length.as_in_context(ctx)
                segment_ids = mx.nd.zeros(shape=token_ids.shape, ctx=ctx)
                masked_positions = masked_positions.as_in_context(ctx).reshape(-1, 1)
                out = self._model(token_ids, segment_ids, valid_length, masked_positions)

                # Get the log-probability assigned to the correct (masked-out) token
                split_size = token_ids.shape[0]
                # out[0] contains the contextual representations
                # out[1] contains the distribution over the vocabulary for the masked positions
                out = out[1].log_softmax(temperature=temp)

                token_masked_ids = token_masked_ids.as_in_context(ctx).reshape(-1)
                for i in range(out.shape[0]):
                    # Row index: the sentence's true token length ([CLS]/[SEP] excluded);
                    # column index: the masked position, shifted past [CLS]
                    num_bins = int(valid_length[i].asscalar())-2
                    bin_counts_per_ctx[ctx_idx][num_bins, masked_positions[i]-1] += 1
                    bin_sums_per_ctx[ctx_idx][num_bins, masked_positions[i]-1] += out[i, 0, token_masked_ids[i]]
                    # Leftover debugging breakpoint, triggered on token id 1012 (e.g. '.' in bert-base-uncased)
                    if token_masked_ids[i].asscalar() == 1012:
                        import pdb; pdb.set_trace()

            # Progress
            sent_count += batch_size
            if (batch_id+1) % batch_log_interval == 0:
                logging.info("{} sents of {}, batch {} of {}".format(sent_count, len(dataset), batch_id+1, len(batch_sampler)))

        # Merge the per-context counts and sums into the CPU-side arrays
        for ctx_idx in range(len(self._ctxs)):
            bin_counts += bin_counts_per_ctx[ctx_idx].asnumpy()
            bin_sums += bin_sums_per_ctx[ctx_idx].asnumpy()

        return bin_counts, bin_sums
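
A minimal usage sketch, assuming `scorer` is an instance of the scorer class above and `corpus` is an already-loaded `Corpus` (both names are placeholders, not from this file): the two returned arrays can be combined into per-bin mean log-probabilities.

    import numpy as np

    # Hypothetical call; parameters follow the signature of bin() above
    bin_counts, bin_sums = scorer.bin(corpus, temp=1.0, split_size=2000)

    # Mean log-probability per (true sentence length, masked position) bin;
    # bins that never received a token are left as NaN
    with np.errstate(divide='ignore', invalid='ignore'):
        bin_means = np.where(bin_counts > 0, bin_sums / bin_counts, np.nan)

    # Average log-probability of the token at position 2 (0-indexed, [CLS] excluded)
    # in sentences whose true length is 10 tokens
    print(bin_means[10, 2])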