def _loglikelihood_tokens()

in lm_eval/models/nemo_lm.py


    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        res = []

        def _collate(x):
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

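        # Reorder requests so the longest (context + continuation) sequences come first,
        # which groups similarly sized inputs into the same batch; the original order is
        # restored by re_ord.get_original() at the end.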
        re_ord = Collator(requests, sort_fn=_collate)
        chunks = re_ord.get_batched(n=self.batch_size, batch_fn=None)
        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running loglikelihood requests",
        )
        for chunk in chunks:
            inps = []
            ctxlens = []
            contlens = []

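            # Left-truncate each request to fit the context window (minus the one token
            # reserved for generation); ctxlen records how many context tokens survive,
            # i.e. where the continuation starts inside the truncated input.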
            for _, context_enc, continuation_enc in chunk:
                # Leave one token for generation; tokens_to_generate=0 breaks NeMo.
                inp = (context_enc + continuation_enc)[-(self.max_length - 1) :]

                ctxlen = len(context_enc) - max(
                    0, len(context_enc) + len(continuation_enc) - (self.max_length - 1)
                )
                ctxlens.append(ctxlen)
                contlens.append(len(continuation_enc))

                inps.append(self.tok_decode(inp))

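            # Single batched NeMo generate call: compute_logprob=True returns per-token
            # logprobs, and all_probs=True additionally returns the full-vocabulary
            # logprobs ("full_logprob") used for the greedy check below.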
            output = self.generate(
                self.model,
                inputs=inps,
                tokens_to_generate=1,
                min_tokens_to_generate=1,
                compute_logprob=True,
                all_probs=True,
            )

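            # Drop the trailing position added by generation; any further extra tokens
            # (shorter samples in the batch generate more) are trimmed per sample below.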
            batch_token_ids = np.asarray(output["token_ids"])[:, :-1]
            batch_logprobs = output["logprob"][:, :-1]
            batch_full_logprob = output["full_logprob"][:, :-1, :]

            # Compute greedy tokens for the entire batch rather than per sample with its proper ctxlen.
            # Additional tokens for each sample will be trimmed later.
            min_ctxlen = min(ctxlens)

            # Use min_ctxlen-1 instead of min_ctxlen since full_logprob is not returned for the first token.
            batch_greedy_tokens = (
                torch.argmax(batch_full_logprob[:, min_ctxlen - 1 :, :], -1)
                .cpu()
                .numpy()
            )

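            # Walk the batch outputs alongside the original requests to build one
            # (logprob, is_greedy) answer per sample.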
            for token_ids, greedy_tokens, logprobs, ctxlen, contlen, (
                cache_key,
                _,
                _,
            ) in zip(
                batch_token_ids,
                batch_greedy_tokens,
                batch_logprobs,
                ctxlens,
                contlens,
                chunk,
            ):
                # Trim at contlen since shorter contexts in a batch will have more than one token generated.
                # Use ctxlen-1 instead of ctxlen, same as for full_logprob in the batch_greedy_tokens calculation.
                logprobs = (logprobs[ctxlen - 1 :])[:contlen]
                logprob = sum(logprobs).tolist()

                continuation_tokens = (token_ids[ctxlen:])[:contlen]
                len_diff = ctxlen - min_ctxlen
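                # greedy_tokens starts at position min_ctxlen-1, so index len_diff lines up
                # with the prediction for this sample's first continuation token.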
                is_greedy = continuation_tokens == (greedy_tokens[len_diff:])[:contlen]
                if not isinstance(is_greedy, bool):
                    is_greedy = is_greedy.all()
                answer = (logprob, is_greedy)

                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)

                res.append(answer)
                pbar.update(1)

        pbar.close()

        return re_ord.get_original(res)
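

For orientation, a minimal sketch of the request/answer shapes this method works with. It assumes the usual lm-eval-harness convention that each request is (cache_key, context_enc, continuation_enc), with cache_key being the (context, continuation) string pair; the token ids, the `lm` name, and the returned values are hypothetical, and no NeMo model is actually invoked:

    # Hypothetical requests; token ids are made up for illustration.
    requests = [
        (("2+2=", " 4"), [17, 28, 17, 31], [604]),
        (("The sky is", " blue"), [464, 6766, 318], [4171]),
    ]
    # answers = lm._loglikelihood_tokens(requests)   # lm: an instantiated NeMo LM wrapper
    # answers -> [(-1.2, True), (-0.3, True)]        # one (logprob, is_greedy) per request,
    #                                                # in the original order; values illustrative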