in lm_eval/models/nemo_lm.py [0:0]
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
    res = []

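    # Sort key: longest (context + continuation) first, with the token tuple as a
    # tie-breaker, so similarly sized requests land in the same batch.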
    def _collate(x):
        toks = x[1] + x[2]
        return -len(toks), tuple(toks)

    re_ord = Collator(requests, sort_fn=_collate)
    chunks = re_ord.get_batched(n=self.batch_size, batch_fn=None)
    pbar = tqdm(
        total=len(requests),
        disable=(disable_tqdm or (self.rank != 0)),
        desc="Running loglikelihood requests",
    )
    for chunk in chunks:
        inps = []
        ctxlens = []
        contlens = []

        for _, context_enc, continuation_enc in chunk:
            # Leave one token for generation. Tokens_to_generate = 0 breaks NeMo.
            inp = (context_enc + continuation_enc)[-(self.max_length - 1) :]

            ctxlen = len(context_enc) - max(
                0, len(context_enc) + len(continuation_enc) - (self.max_length - 1)
            )
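            # Worked example (illustrative numbers): with max_length = 8, a 5-token
            # context and a 4-token continuation give 9 tokens, trimmed to the last 7,
            # so ctxlen = 5 - max(0, 5 + 4 - 7) = 3 and the continuation starts at
            # position 3 of the trimmed sequence.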
            ctxlens.append(ctxlen)
            contlens.append(len(continuation_enc))
            inps.append(self.tok_decode(inp))

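        # Prompts go in as decoded text: this generate path consumes string inputs,
        # hence the tok_decode above.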
        output = self.generate(
            self.model,
            inputs=inps,
            tokens_to_generate=1,
            min_tokens_to_generate=1,
            compute_logprob=True,
            all_probs=True,
        )

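        # Drop the last position of each output: it belongs to the single forced
        # generation step, so the remaining arrays line up with the prompt tokens.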
        batch_token_ids = np.asarray(output["token_ids"])[:, :-1]
        batch_logprobs = output["logprob"][:, :-1]
        batch_full_logprob = output["full_logprob"][:, :-1, :]

        # Compute greedy tokens for the entire batch rather than taking the argmax
        # with the proper ctxlen for each sample; the extra tokens per sample are
        # trimmed later.
        min_ctxlen = min(ctxlens)

        # Use min_ctxlen - 1 instead of min_ctxlen since full_logprob is not returned
        # for the first token.
        batch_greedy_tokens = (
            torch.argmax(batch_full_logprob[:, min_ctxlen - 1 :, :], -1)
            .cpu()
            .numpy()
        )

        for token_ids, greedy_tokens, logprobs, ctxlen, contlen, (
            cache_key,
            _,
            _,
        ) in zip(
            batch_token_ids,
            batch_greedy_tokens,
            batch_logprobs,
            ctxlens,
            contlens,
            chunk,
        ):
            # Trim at contlen since shorter contexts in a batch will have more than
            # one token generated.
            # Use ctxlen - 1 instead of ctxlen, same as for full_logprob in the
            # batch_greedy_tokens calculation.
            logprobs = (logprobs[ctxlen - 1 :])[:contlen]
            logprob = sum(logprobs).tolist()

            continuation_tokens = (token_ids[ctxlen:])[:contlen]
            len_diff = ctxlen - min_ctxlen
            is_greedy = continuation_tokens == (greedy_tokens[len_diff:])[:contlen]
            if not isinstance(is_greedy, bool):
                is_greedy = is_greedy.all()

            answer = (logprob, is_greedy)

            if cache_key is not None:
                self.cache_hook.add_partial("loglikelihood", cache_key, answer)

            res.append(answer)
            pbar.update(1)

    pbar.close()

    return re_ord.get_original(res)