in src/lighteval/models/endpoints/endpoint_model.py [0:0]
def loglikelihood_rolling(self, requests: list[Doc], override_bs=None) -> list[ModelResponse]:
    """Score the log likelihood of each full context, as needed by perplexity metrics.

    Args:
        requests: Documents whose contexts are scored in "rolling" mode (the
            whole context is scored, with no separate continuation).
        override_bs: Optional batch size that takes precedence over the
            module-level BATCH_SIZE default.

    Returns:
        One ModelResponse per request, restored to the callers' original order.
    """
    dataset = LoglikelihoodDataset(requests=requests, num_dataset_splits=self.DATASET_SPLITS)
    effective_bs = BATCH_SIZE if override_bs is None else override_bs
    results: list[ModelResponse] = []

    for split in tqdm(
        dataset.splits_iterator(),
        total=dataset.num_dataset_splits,
        desc="Splits",
        position=0,
        disable=self.disable_tqdm,
    ):
        # Identity collate: keep each batch as a plain list of documents.
        loader = DataLoader(split, batch_size=effective_bs, collate_fn=lambda batch: batch)
        for batch in tqdm(
            loader, desc="Loglikelihoods, rolling", position=1, leave=False, disable=self.disable_tqdm
        ):
            if self.use_async:
                responses = asyncio.run(self._async_process_batch_logprob(batch, rolling=True))
            else:
                responses = self._process_batch_logprob(batch, rolling=True)

            for response in responses:
                # The final token is excluded from both the score and the
                # returned token ids — presumably its logprob refers to a
                # position past the input; confirm against the endpoint's
                # details schema.
                scored_tokens = response.details.tokens[:-1]
                token_logprobs = [tok.logprob for tok in scored_tokens]
                results.append(
                    ModelResponse(
                        result=sum(token_logprobs),
                        input_tokens=[tok.id for tok in response.details.prefill],
                        generated_tokens=[tok.id for tok in scored_tokens],
                        # -1 sentinels: truncation/padding counts are not
                        # tracked for endpoint-backed models.
                        truncated_tokens_count=-1,
                        padded_tokens_count=-1,
                    )
                )

    return dataset.get_original_order(results)