def score()

in src/mlm/scorers.py [0:0]


    def score(self, corpus: Corpus, temp: float = 1.0, split_size: int = 2000, ratio: float = 0, per_token: bool = False) -> Tuple[List, List[int]]:

        assert temp == 1.0

        # Turn corpus into a BERT-ready Dataset
        dataset = self.corpus_to_dataset(corpus)

        # Turn Dataset into Dataloader
        # Fields: sent_idx, token_ids (padded), valid_length, masked_position, token_masked_id, normalization
        batchify_fn = btf_generic.Tuple(btf_generic.Stack(dtype='int32'),
                                        btf_generic.Pad(pad_val=self._tokenizer._convert_token_to_id(self._tokenizer.pad_token), dtype='long'),
                                        btf_generic.Stack(dtype='long'), btf_generic.Stack(dtype='long'),
                                        btf_generic.Stack(dtype='long'), btf_generic.Stack(dtype='long'))

        # TODO: There is a 'by-design' bug in FixedBucketSampler with num_shards > 0, where it silently reuses the last utterances:
        # https://github.com/dmlc/gluon-nlp/blame/b1b61d3f90cf795c7b48b6d109db7b7b96fa21ff/src/gluonnlp/data/sampler.py#L398
        # batch_sampler = nlp.data.sampler.FixedBucketSampler([sent_tuple[2] for sent_tuple in dataset], batch_size=split_size, ratio=ratio, num_shards=len(self._ctxs), shuffle=False)
        # Hence, we use num_shards = 0 and do gluon's split_data
        batch_sampler = nlp.data.sampler.FixedBucketSampler([sent_tuple[2] for sent_tuple in dataset], batch_size=split_size, ratio=ratio, num_shards=0, shuffle=False)

        logging.info(batch_sampler.stats())

        # dataloader = nlp.data.ShardedDataLoader(dataset, pin_memory=True, batch_sampler=batch_sampler, batchify_fn=batchify_fn, num_workers=num_workers, thread_pool=True)
        dataloader = nlp.data.ShardedDataLoader(dataset, batch_sampler=batch_sampler, batchify_fn=batchify_fn)

        # Compute each sentence's true token length (assumes the dataset is ordered by sentence)
        prev_sent_idx = None
        true_tok_lens = []
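        # The dataset has one row per masked position, so consecutive rows share a sentence index;
        # valid_length counts the two added special tokens ([CLS]/[SEP] for BERT), hence the -2.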
        for (curr_sent_idx, _, valid_length, _, _, _) in dataset:
            if curr_sent_idx != prev_sent_idx:
                prev_sent_idx = curr_sent_idx
                true_tok_lens.append(valid_length - 2)

        # Allocate score containers (a running total per sentence, or one slot per position)
        if per_token:
            scores_per_token = [[None]*(true_tok_len+2) for true_tok_len in true_tok_lens]
        else:
            scores = np.zeros((len(corpus),))

        sent_count = 0
        batch_log_interval = 20

        batch_score_accumulation = 1
        batch_sent_idxs_per_ctx = [[] for ctx in self._ctxs]
        batch_scores_per_ctx = [[] for ctx in self._ctxs]
        batch_masked_positions_per_ctx = [[] for ctx in self._ctxs]

        def sum_accumulated_scores():
            for ctx_idx in range(len(self._ctxs)):
                for batch_sent_idxs, batch_scores, batch_masked_positions in zip(batch_sent_idxs_per_ctx[ctx_idx], batch_scores_per_ctx[ctx_idx], batch_masked_positions_per_ctx[ctx_idx]):
                    if per_token:
                        # Slow; only use when necessary
                        for batch_sent_idx, batch_score, batch_masked_position in zip(batch_sent_idxs, batch_scores, batch_masked_positions):
                            # scores_per_token[batch_sent_idx.asscalar()][int(batch_masked_position.asscalar())] = batch_score.asscalar().item()
                            scores_per_token[batch_sent_idx][batch_masked_position.cpu().numpy().item()] = batch_score.cpu().numpy().item()
                    else:
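                        # np.add.at does unbuffered accumulation, so repeated sentence indices
                        # within a batch all contribute; illustrative values only:
                        #   s = np.zeros(3); s[[0, 0]] += 1.0   # s[0] == 1.0 (last write wins)
                        #   np.add.at(s, [0, 0], 1.0)           # s[0] == 2.0 (both updates applied)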
                        # np.add.at(scores, batch_sent_idxs.asnumpy(), batch_scores.asnumpy())
                        np.add.at(scores, batch_sent_idxs, batch_scores.cpu().numpy())
                batch_sent_idxs_per_ctx[ctx_idx] = []
                batch_scores_per_ctx[ctx_idx] = []
                batch_masked_positions_per_ctx[ctx_idx] = []

        # Score each batch; every row carries exactly one masked position
        for batch_id, batch in enumerate(dataloader):

            batch_size = 0

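            # num_shards=0 makes the loader yield a single unsharded batch; wrap it in a 1-tuple
            # so the per-context bookkeeping below still works with one device.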
            for ctx_idx, (sent_idxs, token_ids, valid_length, masked_positions, token_masked_ids, normalization) in enumerate((batch,)):

                ctx = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                batch_size += sent_idxs.shape[0]

                # TODO: Super inefficient where we go from MXNet to NumPy to PyTorch

                with torch.no_grad():

                    token_ids = torch.tensor(token_ids)
                    valid_length = torch.tensor(valid_length)
                    masked_positions = torch.tensor(masked_positions).reshape(-1, 1)
                    token_masked_ids = torch.tensor(token_masked_ids).reshape(-1)

                    token_ids = token_ids.to(ctx)
                    valid_length = valid_length.to(ctx)
                    masked_positions = masked_positions.to(ctx)
                    token_masked_ids = token_masked_ids.to(ctx)

                    split_size = token_ids.shape[0]

                    if isinstance(self._model.module, (AlbertForMaskedLMOptimized,
                                                       BertForMaskedLMOptimized,
                                                       DistilBertForMaskedLMOptimized)):
                        # Because BERT does not take a length parameter
                        alen = torch.arange(token_ids.shape[1], dtype=torch.long)
                        alen = alen.to(ctx)
                        mask = alen < valid_length[:, None]
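                        # e.g. valid_length = [3, 5] with padded length 5 gives
                        # mask = [[1, 1, 1, 0, 0],
                        #         [1, 1, 1, 1, 1]]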
                        out = self._model(input_ids=token_ids, attention_mask=mask, select_positions=masked_positions)
                        out = out[0].squeeze(dim=1)  # (batch_size, 1, vocab_size) -> (batch_size, vocab_size); dim=1 avoids dropping a batch of size 1
                    elif isinstance(self._model.module, transformers.BertForMaskedLM):
                        # Because BERT does not take a length parameter
                        alen = torch.arange(token_ids.shape[1], dtype=torch.long)
                        alen = alen.to(ctx)
                        mask = alen < valid_length[:, None]
                        out = self._model(input_ids=token_ids, attention_mask=mask)
                        # out[0] holds the logits for every position: (batch_size, sequence_length, config.vocab_size)
                        # Keep only the logits at each example's masked position: (batch_size, config.vocab_size)
                        out = out[0][list(range(split_size)), masked_positions.reshape(-1), :]
                    elif isinstance(self._model.module, transformers.XLMWithLMHeadModel):
                        if self._lang is not None and self._tokenizer.lang2id is not None:
                            langs = torch.ones_like(token_ids)*self._tokenizer.lang2id[self._lang]
                        else:
                            langs = None
                        out = self._model(input_ids=token_ids, lengths=valid_length, langs=langs)
                        # out[0] holds the logits for every position: (batch_size, sequence_length, config.vocab_size)
                        # Keep only the logits at each example's masked position: (batch_size, config.vocab_size)
                        out = out[0][list(range(split_size)), masked_positions.reshape(-1), :]
                    else:
                        raise ValueError("Unsupported model type: {}".format(type(self._model.module)))

                    # TODO: Manual numerically-stable softmax
                    # https://stackoverflow.com/questions/42599498/numercially-stable-softmax
                    # Because we only need one scalar
                    out = out.log_softmax(dim=-1)

                    # Gather the log-probability assigned to the true token at each masked position,
                    # and stash it (with sentence indices and masked positions) for aggregation later
                    batch_sent_idxs_per_ctx[ctx_idx].append(sent_idxs)
                    out = out[list(range(split_size)), token_masked_ids]
                    batch_scores_per_ctx[ctx_idx].append(out)
                    batch_masked_positions_per_ctx[ctx_idx].append(masked_positions)

            # Ideally we'd accumulate the scores when possible, but something like the below won't work
            # > scores[sent_idxs] += out
            # See In[21] in https://jakevdp.github.io/PythonDataScienceHandbook/02.07-fancy-indexing.html.
            # Hence, aggregation is done synchronously, every so often
            # (though batch_score_accumulation = 1 seems best, since bucketing is effective in reducing GPU disparity)
            if len(batch_sent_idxs_per_ctx[0]) == batch_score_accumulation:   
                sum_accumulated_scores()

            # Progress
            sent_count += batch_size
            if (batch_id+1) % batch_log_interval == 0:
                logging.info("{} sents of {}, batch {} of {}".format(sent_count, len(dataset), batch_id+1, len(batch_sampler)))

        # TODO: Test score accumulation
        # In case there are leftovers
        sum_accumulated_scores()

        if per_token:
            return scores_per_token, true_tok_lens
        else:
            return scores.tolist(), true_tok_lens
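
A minimal usage sketch (hypothetical scorer and corpus names, not from the source): score() returns per-sentence pseudo-log-likelihoods plus the true token lengths, which can be combined into a corpus-level pseudo-perplexity.

    import math

    # scorer is assumed to be an instance of the scorer class above; corpus a Corpus
    scores, true_tok_lens = scorer.score(corpus, split_size=1000)

    # Each score is the sum of the log-probabilities of a sentence's tokens, each predicted
    # with that token masked (its pseudo-log-likelihood)
    pll_total = sum(scores)
    num_tokens = sum(true_tok_lens)
    pppl = math.exp(-pll_total / num_tokens)  # one common aggregate: pseudo-perplexity
    print("PLL = {:.2f} over {} tokens, PPPL = {:.2f}".format(pll_total, num_tokens, pppl))

With per_token=True the first return value is instead a list of per-position score lists, with None left at positions that were never masked (e.g. the [CLS]/[SEP] slots).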