def _detect_batch_size()

in lm_eval/models/huggingface.py
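
Auto-detects the largest batch size that fits in device memory: a worst-case dummy batch is built from the pending requests, pushed through the model under accelerate's find_executable_batch_size decorator (which halves the batch size on OOM), and, when running on multiple ranks, the minimum of the per-rank results is returned.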


    def _detect_batch_size(self, requests=None, pos: int = 0) -> int:
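        # requests are either 3-tuples (logprob evals: the last two elements are
        # pre-tokenized context and continuation) or 2-tuples (generative evals:
        # raw context string plus a gen_kwargs dict)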
        if len(requests[0]) == 3:  # logprob evals
            _, context_enc, continuation_enc = requests[pos]
            max_length = len(
                (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
            )
            max_context_enc = len(context_enc[-(self.max_length + 1) :])
            max_cont_enc = len(continuation_enc[-(self.max_length + 1) :])
            # batch sizes for logprob evals sometimes generate OOMs
            security_margin_factor = 4
        elif len(requests[0]) == 2:  # generative evals
            # worst case for generative evals: longest prompt plus its generation budget
            longest_context = max(
                [
                    len(self.tok_encode(request[0]))
                    + request[1].get("max_gen_toks", self.max_length)
                    for request in requests[pos:]
                ]
            )
            if longest_context > self.max_length:
                eval_logger.warning(
                    f"Longest context length of {longest_context} exceeds max_length of {self.max_length}. Truncating to max_length."
                )
                longest_context = self.max_length
            max_length = longest_context
            max_context_enc = max_length
            max_cont_enc = max_length
            security_margin_factor = 4

        # on OOM, find_executable_batch_size halves batch_size and tries again
        @find_executable_batch_size(starting_batch_size=self.max_batch_size)
        def forward_batch(batch_size):
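            # number of extra probe rows used as a safety margin against later OOMs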
            security_margin = int(0.05 * security_margin_factor * batch_size)
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
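                # seq2seq models need both encoder inputs and decoder labels,
                # so build a dummy tensor for each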
                length = max(max_context_enc, max_cont_enc)
                batched_conts = torch.ones(
                    (batch_size, length), device=self.device
                ).long()
                test_batch = torch.ones(
                    (batch_size + security_margin, length), device=self.device
                ).long()
                call_kwargs = {
                    "attn_mask": test_batch,
                    "labels": batched_conts,
                }
            else:
                call_kwargs = {}
                test_batch = torch.ones(
                    (batch_size, max_length), device=self.device
                ).long()
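            # repeat the dummy forward pass so a batch size that only barely fits
            # is more likely to OOM here (and be halved) than during the real eval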
            for _ in range(5 * security_margin_factor):
                logits = self._model_call(inps=test_batch, **call_kwargs).float()
                scores = F.log_softmax(logits, dim=-1)  # noqa: F841

            return batch_size

        try:
            batch_size = forward_batch()
        except RuntimeError as e:
            if "No executable batch size found" in str(e):
                batch_size = 1
            else:
                raise

        if self.world_size > 1:
            # if multi-GPU, always take minimum over all selected batch sizes
            max_rnk_bs = torch.tensor([batch_size], device=self.device)
            gathered = (
                self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist()
            )
            batch_size = min(gathered)
            clear_torch_cache()
            return batch_size

        clear_torch_cache()
        return batch_size
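
The OOM backoff above comes from the find_executable_batch_size decorator in Hugging Face accelerate. Below is a minimal, self-contained sketch of that pattern, assuming a decoder-only Hugging Face model already on the target device; probe_batch_size and its arguments are illustrative helpers, not part of lm-eval.

    import torch
    from accelerate import find_executable_batch_size


    def probe_batch_size(model, seq_len, device="cuda", start=64):
        """Return the largest batch size whose dummy forward pass does not OOM."""

        @find_executable_batch_size(starting_batch_size=start)
        def forward_batch(batch_size):
            # Dummy forward pass; on CUDA OOM the decorator retries with
            # batch_size halved, raising a RuntimeError once it reaches zero.
            dummy = torch.ones((batch_size, seq_len), dtype=torch.long, device=device)
            with torch.no_grad():
                model(dummy)
            return batch_size

        try:
            return forward_batch()
        except RuntimeError as e:
            if "No executable batch size found" in str(e):
                return 1  # same fallback as _detect_batch_size above
            raise

The try/except mirrors the fallback to a batch size of 1 in _detect_batch_size; taking the minimum across ranks (as in the world_size > 1 branch) would be layered on top for multi-GPU runs.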