in lm_eval/models/huggingface.py [0:0]
def _detect_batch_size(self, requests=None, pos: int = 0) -> int:
    if requests and len(requests[0]) == 3:  # logprob evals
        _, context_enc, continuation_enc = requests[pos]
        # Worst-case input length: context + continuation truncated to the model
        # window; the final token is a prediction target only, hence the [:-1].
        max_length = len(
            (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
        )
        max_context_enc = len(context_enc[-(self.max_length + 1) :])
        max_cont_enc = len(continuation_enc[-(self.max_length + 1) :])
        # batch sizes for logprob evals sometimes generate OOMs
        security_margin_factor = 4
    elif requests and len(requests[0]) == 2:  # generative evals
        # using rolling window with maximum context
        longest_context = max(
            [
                len(self.tok_encode(request[0]))
                + request[1].get("max_gen_toks", self.max_length)
                for request in requests[pos:]
            ]
        )
        if longest_context > self.max_length:
            eval_logger.warning(
                f"Longest context length of {longest_context} exceeds max_length of {self.max_length}. Truncating to max_length."
            )
            longest_context = self.max_length
        max_length = longest_context
        max_context_enc = max_length
        max_cont_enc = max_length
        security_margin_factor = 4
    else:
        # No requests supplied (or unrecognized request shape): assume a
        # full-length context as the worst case.
        max_length = self.max_length
        max_context_enc = max_length
        max_cont_enc = max_length
        security_margin_factor = 4

    # if OOM, then halves batch_size and tries again
    @find_executable_batch_size(starting_batch_size=self.max_batch_size)
    def forward_batch(batch_size):
        # Probe with a slightly larger batch than will be reported, so the
        # returned batch size keeps some memory headroom during the real run.
        security_margin = int(0.05 * security_margin_factor * batch_size)
        if self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
            length = max(max_context_enc, max_cont_enc)
            # Dummy all-ones tensors: only their shapes matter for the probe.
            batched_conts = torch.ones(
                (batch_size + security_margin, length), device=self.device
            ).long()
            test_batch = torch.ones(
                (batch_size + security_margin, length), device=self.device
            ).long()
            call_kwargs = {
                "attn_mask": test_batch,
                "labels": batched_conts,
            }
        else:
            call_kwargs = {}
            test_batch = torch.ones(
                (batch_size + security_margin, max_length), device=self.device
            ).long()
        for _ in range(5 * security_margin_factor):
            logits = self._model_call(inps=test_batch, **call_kwargs).float()
            scores = F.log_softmax(logits, dim=-1)  # noqa: F841
        return batch_size
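    # find_executable_batch_size (from accelerate) retries forward_batch with a
    # halved batch size on out-of-memory errors and raises a RuntimeError whose
    # message contains "No executable batch size found" once the size reaches
    # zero, which is exactly what the handler below falls back on.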
    try:
        batch_size = forward_batch()
    except RuntimeError as e:
        if "No executable batch size found" in str(e):
            batch_size = 1
        else:
            raise

    if self.world_size > 1:
        # if multi-GPU, always take minimum over all selected batch sizes
        max_rnk_bs = torch.tensor([batch_size], device=self.device)
        gathered = (
            self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist()
        )
        batch_size = min(gathered)
        clear_torch_cache()
        return batch_size

    clear_torch_cache()
    return batch_size
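
# --- Illustrative sketch (not from huggingface.py) ---------------------------
# A minimal, self-contained demo of the halving behaviour that _detect_batch_size
# relies on: accelerate's find_executable_batch_size keeps retrying the wrapped
# probe with half the batch size whenever it raises an OOM-style RuntimeError.
# The function and variable names here are invented for the demo, and the
# `largest_that_fits` threshold is a made-up stand-in for real GPU memory
# pressure.
from accelerate.utils import find_executable_batch_size


def demo_detect_batch_size(
    starting_batch_size: int = 64, largest_that_fits: int = 10
) -> int:
    @find_executable_batch_size(starting_batch_size=starting_batch_size)
    def probe(batch_size: int) -> int:
        # Pretend the forward pass OOMs whenever the probe batch is too large.
        if batch_size > largest_that_fits:
            raise RuntimeError("CUDA out of memory.")
        return batch_size

    try:
        return probe()  # the decorator injects batch_size and halves it on OOM
    except RuntimeError as e:
        if "No executable batch size found" in str(e):
            return 1  # mirror the fallback used above
        raise


if __name__ == "__main__":
    print(demo_detect_batch_size())  # 64 -> 32 -> 16 -> 8, prints 8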