def warmup()

in text-generation-inference/server/text_generation_server/generator.py


    def warmup(self, batch: Batch) -> int:
        """Verify if the hardware can support the target load.

        Args:
            batch (`Batch`):
                A batch corresponding to the maximum number of concurrent requests.

        Returns:
            The maximum number of tokens the model supports.
        """
        logger.debug("Warming up the model")
        start = time.time()
        # Just check that the warmup request parameters match the model capacity
        # NOTE: later self.model.config.batch_size might become self.model.config.max_batch_size.
        if self.model.config.batch_size is not None:
            batch_size = self.model.config.batch_size
        else:
            # Batch size is not set; assume it is unlimited and accept all requests.
            batch_size = len(batch.requests)
        if len(batch.requests) > batch_size:
            raise ValueError(
                f"Inconsistent server configuration: please make sure max-prefill-tokens does not exceed {batch_size} x max-input-length."
            )

        # Counter-intuitively, we now ignore the input batch. Instead, we create dummy batches that cover all
        # supported batch sizes and sequence lengths.
        seq_len = self.model.config.sequence_length
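        # Warmup can be skipped entirely by setting the SKIP_WARMUP environment variable to "1";
        # in that case the token budget is returned without running any dummy batch.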
        if os.environ.get("SKIP_WARMUP", "0") == "1":
            logger.debug("Skipping warmup")
            return batch_size * seq_len
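        # Map the configured sequence length onto the nearest warmup bucket; prefill lengths above
        # this bucket are skipped in the loop below.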
        bucket_seq_len = take_nearest_length(seq_len)
        requests = [self._create_dummy_request(seq_len) for _ in range(batch_size)]
        for _ in reversed(range(batch_size)):
            # Prefill with different truncate sizes to test all prefill lengths. The list is reversed so the
            # longest sequences are tested first; if there is a memory failure, it will surface sooner.
            for l in reversed(PREFILL_LENGTHS):
                # Skip all the unsupported lengths
                if l > bucket_seq_len:
                    continue
                # Set all truncate values for all requests
                for r in requests:
                    r.truncate = l
                    r.stopping_parameters.max_new_tokens = 10
                warmup_batch = Batch(id=0,
                                     requests=requests,
                                     size=len(requests),
                                     max_tokens=batch.max_tokens)
                logger.debug(f"Warmup for {len(requests)} requests, truncate value {l} seq_len {seq_len}")
                _generations, next_batch = self.prefill(warmup_batch)
                if next_batch is not None:
                    self.decode([next_batch])
                else:
                    logger.debug(f"No decode on warmup for {len(requests)}x{l}")
                self.clear()
            # Remove the last request to decrease the batch size for the next iteration.
            requests.pop()

        elapsed = time.time() - start
        logger.debug(f"Warmup done, took {elapsed:.2f}s")
        return batch_size * seq_len
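
The helpers used above, PREFILL_LENGTHS and take_nearest_length, are defined elsewhere in generator.py. The sketch below is a minimal illustration of the bucketing logic the warmup relies on, assuming PREFILL_LENGTHS is a sorted list of bucket sizes; the bucket values and the exact rounding rule are assumptions, not the module's actual definitions.

    import bisect

    # Assumed bucket table, for illustration only; the real PREFILL_LENGTHS in generator.py
    # may contain different values.
    PREFILL_LENGTHS = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]

    def take_nearest_length(length: int) -> int:
        """Return the smallest bucket that can hold `length` tokens (sketch only).

        Falls back to the largest bucket when `length` exceeds every entry.
        """
        index = bisect.bisect_left(PREFILL_LENGTHS, length)
        if index >= len(PREFILL_LENGTHS):
            return PREFILL_LENGTHS[-1]
        return PREFILL_LENGTHS[index]

With a helper like this, the warmup loop exercises every bucket up to bucket_seq_len for every batch size from batch_size down to 1, so real requests later arrive at shapes the model has already seen during warmup.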