in src/lighteval/models/endpoints/endpoint_model.py [0:0]
def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]:
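    """Compute log-likelihoods of the tokenized continuations for a list of docs.

    Requests are grouped into dataset splits, batched, and sent to the endpoint
    (asynchronously when `self.use_async` is set). Each result stores the summed
    continuation logprob together with a flag for whether the continuation matched
    the greedy choice, and the results are returned in the original request order.
    """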
    dataset = LoglikelihoodDataset(requests=docs, num_dataset_splits=self.DATASET_SPLITS)
    batch_size = self.config.batch_size
    results = []

    for split in tqdm(
        dataset.splits_iterator(),
        total=dataset.num_dataset_splits,
        desc="Splits",
        position=0,
        disable=self.disable_tqdm,
    ):
        dataloader = DataLoader(split, batch_size=batch_size, collate_fn=lambda batch: batch)

        for batch in tqdm(dataloader, desc="Loglikelihoods", position=1, leave=False, disable=self.disable_tqdm):
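            # Score the whole batch either concurrently (async HTTP requests) or sequentially.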
            if self.use_async:
                responses = asyncio.run(self._async_process_batch_logprob(batch))
            else:
                responses = self._process_batch_logprob(batch)

            for cur_request, response in zip(batch, responses):
                cont_toks = torch.tensor(cur_request.tokenized_continuation)
                len_choice = len(cont_toks)
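                # For a dedicated inference endpoint the continuation's logprobs are read from the
                # end of the echoed prompt (`details.prefill`); for the serverless API they are read
                # from `details.tokens`. In both cases `logits` holds per-token logprobs, not full
                # vocabulary distributions.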
                if self.endpoint:  # inference endpoint
                    logits = [
                        t.logprob for t in response.details.prefill[-len_choice:] if t.logprob is not None
                    ]  # to check
                else:  # serverless endpoint
                    logits = [t.logprob for t in response.details.tokens[-len_choice:] if t.logprob is not None]
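                # Note: `logits` only contains the chosen tokens' logprobs, so this argmax-based
                # greedy check is an approximation (already flagged "to check" above) rather than
                # a comparison against the model's full distribution at each position.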
                greedy_tokens = torch.tensor(logits).argmax(dim=-1)
                max_equal = (greedy_tokens == cont_toks).all().squeeze(0)
                results.append(
                    ModelResponse(
                        logprobs=(sum(logits), bool(max_equal)),
                        input_tokens=[t.id for t in response.details.prefill[:-len_choice]],
                        output_tokens=[t.id for t in response.details.prefill[-len_choice:]],
                        truncated_tokens_count=-1,
                        padded_tokens_count=-1,
                    )
                )

    return dataset.get_original_order(results)
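
# Example usage (illustrative sketch, not part of this file): assuming `model` is an
# already-instantiated endpoint model exposing this method and `docs` is a list of Doc
# objects prepared by the evaluation pipeline, scoring the continuations might look like:
#
#     responses = model.loglikelihood(docs)
#     for doc, resp in zip(docs, responses):
#         summed_logprob, continuation_is_greedy = resp.logprobs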